Skip to content

Commit 78ac17d

Browse files
refactor: GainMatrixUpdater uses AnyMutableTrackState (#4890)
Blocked by: - #4889
1 parent 7e97ac8 commit 78ac17d

File tree

8 files changed

+162
-260
lines changed

8 files changed

+162
-260
lines changed

Core/include/Acts/TrackFitting/GainMatrixSmoother.hpp

Lines changed: 6 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include "Acts/EventData/AnyTrackStateProxy.hpp"
1112
#include "Acts/EventData/MultiTrajectory.hpp"
1213
#include "Acts/Geometry/GeometryContext.hpp"
1314
#include "Acts/Utilities/Delegate.hpp"
@@ -45,41 +46,6 @@ class GainMatrixSmoother {
4546
const Logger& logger = getDummyLogger()) const {
4647
(void)gctx;
4748

48-
using TrackStateProxy = typename traj_t::TrackStateProxy;
49-
50-
GetParameters filtered;
51-
GetCovariance filteredCovariance;
52-
GetParameters smoothed;
53-
GetParameters predicted;
54-
GetCovariance predictedCovariance;
55-
GetCovariance smoothedCovariance;
56-
GetCovariance jacobian;
57-
58-
filtered.connect([](const void*, void* ts) {
59-
return static_cast<TrackStateProxy*>(ts)->filtered();
60-
});
61-
filteredCovariance.connect([](const void*, void* ts) {
62-
return static_cast<TrackStateProxy*>(ts)->filteredCovariance();
63-
});
64-
65-
smoothed.connect([](const void*, void* ts) {
66-
return static_cast<TrackStateProxy*>(ts)->smoothed();
67-
});
68-
smoothedCovariance.connect([](const void*, void* ts) {
69-
return static_cast<TrackStateProxy*>(ts)->smoothedCovariance();
70-
});
71-
72-
predicted.connect([](const void*, void* ts) {
73-
return static_cast<TrackStateProxy*>(ts)->predicted();
74-
});
75-
predictedCovariance.connect([](const void*, void* ts) {
76-
return static_cast<TrackStateProxy*>(ts)->predictedCovariance();
77-
});
78-
79-
jacobian.connect([](const void*, void* ts) {
80-
return static_cast<TrackStateProxy*>(ts)->jacobian();
81-
});
82-
8349
ACTS_VERBOSE("Invoked GainMatrixSmoother on entry index: " << entryIndex);
8450

8551
// For the last state: smoothed is filtered - also: switch to next
@@ -118,9 +84,8 @@ class GainMatrixSmoother {
11884
// ensure the track state has a smoothed component
11985
ts.addComponents(TrackStatePropMask::Smoothed);
12086

121-
if (auto res = calculate(&ts, &prev_ts, filtered, filteredCovariance,
122-
smoothed, predicted, predictedCovariance,
123-
smoothedCovariance, jacobian, logger);
87+
if (auto res = calculate(AnyMutableTrackStateProxy{ts},
88+
AnyConstTrackStateProxy{prev_ts}, logger);
12489
!res.ok()) {
12590
error = res.error();
12691
return false;
@@ -146,23 +111,11 @@ class GainMatrixSmoother {
146111
/// formalism.
147112
///
148113
/// @param ts Current track state to be smoothed
149-
/// @param prev_ts Previous track state (already smoothed)
150-
/// @param filtered Delegate to get filtered parameters
151-
/// @param filteredCovariance Delegate to get filtered covariance
152-
/// @param smoothed Delegate to get smoothed parameters
153-
/// @param predicted Delegate to get predicted parameters
154-
/// @param predictedCovariance Delegate to get predicted covariance
155-
/// @param smoothedCovariance Delegate to get smoothed covariance
156-
/// @param jacobian Delegate to get Jacobian matrix
114+
/// @param prev_ts Previous track state (in forward direction)
157115
/// @param logger Logger for verbose output
158116
/// @return Success or failure of the smoothing calculation
159-
Result<void> calculate(void* ts, void* prev_ts, const GetParameters& filtered,
160-
const GetCovariance& filteredCovariance,
161-
const GetParameters& smoothed,
162-
const GetParameters& predicted,
163-
const GetCovariance& predictedCovariance,
164-
const GetCovariance& smoothedCovariance,
165-
const GetCovariance& jacobian,
117+
Result<void> calculate(AnyMutableTrackStateProxy ts,
118+
AnyConstTrackStateProxy prev_ts,
166119
const Logger& logger) const;
167120
};
168121

Core/include/Acts/TrackFitting/GainMatrixUpdater.hpp

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#pragma once
1010

11-
#include "Acts/EventData/MultiTrajectory.hpp"
11+
#include "Acts/EventData/AnyTrackStateProxy.hpp"
1212
#include "Acts/EventData/Types.hpp"
1313
#include "Acts/Geometry/GeometryContext.hpp"
1414
#include "Acts/Utilities/Logger.hpp"
@@ -23,20 +23,6 @@ namespace Acts {
2323
/// Kalman update step using the gain matrix formalism.
2424
/// @ingroup track_fitting
2525
class GainMatrixUpdater {
26-
struct InternalTrackState {
27-
unsigned int calibratedSize;
28-
// This is used to build a covariance matrix view in the .cpp file
29-
const double* calibrated;
30-
const double* calibratedCovariance;
31-
BoundSubspaceIndices projector;
32-
33-
TrackStateTraits<kMeasurementSizeMax, false>::Parameters predicted;
34-
TrackStateTraits<kMeasurementSizeMax, false>::Covariance
35-
predictedCovariance;
36-
TrackStateTraits<kMeasurementSizeMax, false>::Parameters filtered;
37-
TrackStateTraits<kMeasurementSizeMax, false>::Covariance filteredCovariance;
38-
};
39-
4026
public:
4127
/// Run the Kalman update step for a single trajectory state.
4228
///
@@ -70,20 +56,8 @@ class GainMatrixUpdater {
7056
// auto filtered = trackState.filtered();
7157
// auto filteredCovariance = trackState.filteredCovariance();
7258

73-
auto [chi2, error] = visitMeasurement(
74-
InternalTrackState{
75-
trackState.calibratedSize(),
76-
// Note that we pass raw pointers here which are used in the correct
77-
// shape later
78-
trackState.effectiveCalibrated().data(),
79-
trackState.effectiveCalibratedCovariance().data(),
80-
trackState.projectorSubspaceIndices(),
81-
trackState.predicted(),
82-
trackState.predictedCovariance(),
83-
trackState.filtered(),
84-
trackState.filteredCovariance(),
85-
},
86-
logger);
59+
auto [chi2, error] =
60+
visitMeasurement(AnyMutableTrackStateProxy{trackState}, logger);
8761

8862
trackState.chi2() = chi2;
8963

@@ -92,11 +66,11 @@ class GainMatrixUpdater {
9266

9367
private:
9468
std::tuple<double, std::error_code> visitMeasurement(
95-
InternalTrackState trackState, const Logger& logger) const;
69+
AnyMutableTrackStateProxy trackState, const Logger& logger) const;
9670

9771
template <std::size_t N>
9872
std::tuple<double, std::error_code> visitMeasurementImpl(
99-
InternalTrackState trackState, const Logger& logger) const;
73+
AnyMutableTrackStateProxy trackState, const Logger& logger) const;
10074
};
10175

10276
} // namespace Acts

Core/include/Acts/TrackFitting/MbfSmoother.hpp

Lines changed: 13 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
#pragma once
1010

1111
#include "Acts/Definitions/TrackParametrization.hpp"
12+
#include "Acts/EventData/AnyTrackStateProxy.hpp"
1213
#include "Acts/EventData/MultiTrajectory.hpp"
1314
#include "Acts/Geometry/GeometryContext.hpp"
1415
#include "Acts/Utilities/Logger.hpp"
1516
#include "Acts/Utilities/Result.hpp"
1617

1718
#include <cstddef>
1819
#include <optional>
20+
#include <utility>
1921

2022
namespace Acts {
2123

@@ -65,7 +67,7 @@ class MbfSmoother {
6567
// ensure the track state has a smoothed component
6668
ts.addComponents(TrackStatePropMask::Smoothed);
6769

68-
InternalTrackState internalTrackState(ts);
70+
AnyMutableTrackStateProxy internalTrackState(ts);
6971

7072
// Smoothe the current state
7173
calculateSmoothed(internalTrackState, bigLambdaHat, smallLambdaHat);
@@ -77,78 +79,31 @@ class MbfSmoother {
7779

7880
// Update the lambdas depending on the type of track state
7981
if (ts.typeFlags().test(TrackStateFlag::MeasurementFlag)) {
80-
visitMeasurement(internalTrackState, bigLambdaHat, smallLambdaHat);
82+
visitMeasurement(AnyConstTrackStateProxy{ts}, bigLambdaHat,
83+
smallLambdaHat);
8184
} else {
82-
visitNonMeasurement(internalTrackState, bigLambdaHat, smallLambdaHat);
85+
visitNonMeasurement(std::as_const(ts).jacobian(), bigLambdaHat,
86+
smallLambdaHat);
8387
}
8488
});
8589

8690
return Result<void>::success();
8791
}
8892

8993
private:
90-
/// Internal track state representation for the smoother.
91-
/// @note This allows us to move parts of the implementation into the .cpp
92-
struct InternalTrackState final {
93-
using Jacobian =
94-
typename TrackStateTraits<kMeasurementSizeMax, false>::Covariance;
95-
using Parameters =
96-
typename TrackStateTraits<kMeasurementSizeMax, false>::Parameters;
97-
using Covariance =
98-
typename TrackStateTraits<kMeasurementSizeMax, false>::Covariance;
99-
100-
struct Measurement final {
101-
unsigned int calibratedSize{0};
102-
// This is used to build a covariance matrix view in the .cpp file
103-
const double* calibrated{nullptr};
104-
const double* calibratedCovariance{nullptr};
105-
BoundSubspaceIndices projector;
106-
107-
template <typename TrackStateProxy>
108-
explicit Measurement(TrackStateProxy ts)
109-
: calibratedSize(ts.calibratedSize()),
110-
calibrated(ts.effectiveCalibrated().data()),
111-
calibratedCovariance(ts.effectiveCalibratedCovariance().data()),
112-
projector(ts.projectorSubspaceIndices()) {}
113-
};
114-
115-
Jacobian jacobian;
116-
117-
Parameters predicted;
118-
Covariance predictedCovariance;
119-
Parameters filtered;
120-
Covariance filteredCovariance;
121-
Parameters smoothed;
122-
Covariance smoothedCovariance;
123-
124-
std::optional<Measurement> measurement;
125-
126-
template <typename TrackStateProxy>
127-
explicit InternalTrackState(TrackStateProxy ts)
128-
: jacobian(ts.jacobian()),
129-
predicted(ts.predicted()),
130-
predictedCovariance(ts.predictedCovariance()),
131-
filtered(ts.filtered()),
132-
filteredCovariance(ts.filteredCovariance()),
133-
smoothed(ts.smoothed()),
134-
smoothedCovariance(ts.smoothedCovariance()),
135-
measurement(ts.typeFlags().test(TrackStateFlag::MeasurementFlag)
136-
? std::optional<Measurement>(ts)
137-
: std::nullopt) {}
138-
};
139-
14094
/// Calculate the smoothed parameters and covariance.
141-
void calculateSmoothed(InternalTrackState& ts,
95+
void calculateSmoothed(AnyMutableTrackStateProxy& ts,
14296
const BoundMatrix& bigLambdaHat,
14397
const BoundVector& smallLambdaHat) const;
14498

14599
/// Visit a non-measurement track state and update the lambdas.
146-
void visitNonMeasurement(const InternalTrackState& ts,
147-
BoundMatrix& bigLambdaHat,
148-
BoundVector& smallLambdaHat) const;
100+
void visitNonMeasurement(
101+
const AnyConstTrackStateProxy::ConstCovarianceMap& jacobian,
102+
BoundMatrix& bigLambdaHat, BoundVector& smallLambdaHat) const;
149103

150104
/// Visit a measurement track state and update the lambdas.
151-
void visitMeasurement(const InternalTrackState& ts, BoundMatrix& bigLambdaHat,
105+
void visitMeasurement(const AnyConstTrackStateProxy& ts,
106+
BoundMatrix& bigLambdaHat,
152107
BoundVector& smallLambdaHat) const;
153108
};
154109

Core/include/Acts/TrackFitting/detail/GainMatrixUpdaterImpl.hpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include "Acts/EventData/AnyTrackStateProxy.hpp"
1112
#include "Acts/EventData/TrackParameterHelpers.hpp"
1213
#include "Acts/TrackFitting/GainMatrixUpdater.hpp"
1314
#include "Acts/TrackFitting/KalmanFitterError.hpp"
@@ -20,25 +21,24 @@ namespace Acts {
2021

2122
template <std::size_t N>
2223
std::tuple<double, std::error_code> GainMatrixUpdater::visitMeasurementImpl(
23-
InternalTrackState trackState, const Logger& logger) const {
24+
AnyMutableTrackStateProxy trackState, const Logger& logger) const {
2425
double chi2 = 0;
2526

2627
constexpr std::size_t kMeasurementSize = N;
2728
using ParametersVector = ActsVector<kMeasurementSize>;
2829
using CovarianceMatrix = ActsSquareMatrix<kMeasurementSize>;
2930

30-
typename TrackStateTraits<kMeasurementSize, true>::Calibrated calibrated{
31-
trackState.calibrated};
32-
typename TrackStateTraits<kMeasurementSize, true>::CalibratedCovariance
33-
calibratedCovariance{trackState.calibratedCovariance};
31+
auto calibrated = trackState.calibrated<kMeasurementSize>();
32+
auto calibratedCovariance =
33+
trackState.calibratedCovariance<kMeasurementSize>();
3434

3535
ACTS_VERBOSE("Measurement dimension: " << kMeasurementSize);
3636
ACTS_VERBOSE("Calibrated measurement: " << calibrated.transpose());
3737
ACTS_VERBOSE("Calibrated measurement covariance:\n" << calibratedCovariance);
3838

39-
std::span<const std::uint8_t, kMeasurementSize> validSubspaceIndices(
40-
trackState.projector.begin(),
41-
trackState.projector.begin() + kMeasurementSize);
39+
const auto validSubspaceIndices =
40+
trackState.template projectorSubspaceIndices<kMeasurementSize>();
41+
4242
FixedBoundSubspaceHelper<kMeasurementSize> subspaceHelper(
4343
validSubspaceIndices);
4444

@@ -47,11 +47,16 @@ std::tuple<double, std::error_code> GainMatrixUpdater::visitMeasurementImpl(
4747

4848
ACTS_VERBOSE("Measurement projector H:\n" << H);
4949

50-
const auto K = (trackState.predictedCovariance * H.transpose() *
51-
(H * trackState.predictedCovariance * H.transpose() +
52-
calibratedCovariance)
53-
.inverse())
54-
.eval();
50+
auto filtered = trackState.filtered();
51+
auto filteredCovariance = trackState.filteredCovariance();
52+
const auto predicted = trackState.predicted();
53+
const auto predictedCovariance = trackState.predictedCovariance();
54+
55+
const auto K =
56+
(predictedCovariance * H.transpose() *
57+
(H * predictedCovariance * H.transpose() + calibratedCovariance)
58+
.inverse())
59+
.eval();
5560

5661
ACTS_VERBOSE("Gain Matrix K:\n" << K);
5762

@@ -60,17 +65,16 @@ std::tuple<double, std::error_code> GainMatrixUpdater::visitMeasurementImpl(
6065
return {0, KalmanFitterError::UpdateFailed};
6166
}
6267

63-
trackState.filtered =
64-
trackState.predicted + K * (calibrated - H * trackState.predicted);
68+
filtered = predicted + K * (calibrated - H * predicted);
6569
// Normalize phi and theta
66-
trackState.filtered = normalizeBoundParameters(trackState.filtered);
67-
trackState.filteredCovariance =
68-
(BoundSquareMatrix::Identity() - K * H) * trackState.predictedCovariance;
69-
ACTS_VERBOSE("Filtered parameters: " << trackState.filtered.transpose());
70-
ACTS_VERBOSE("Filtered covariance:\n" << trackState.filteredCovariance);
70+
filtered = normalizeBoundParameters(filtered);
71+
filteredCovariance =
72+
(BoundSquareMatrix::Identity() - K * H) * predictedCovariance;
73+
ACTS_VERBOSE("Filtered parameters: " << filtered.transpose());
74+
ACTS_VERBOSE("Filtered covariance:\n" << filteredCovariance);
7175

7276
ParametersVector residual;
73-
residual = calibrated - H * trackState.filtered;
77+
residual = calibrated - H * filtered;
7478
ACTS_VERBOSE("Residual: " << residual.transpose());
7579

7680
CovarianceMatrix m =
@@ -85,10 +89,10 @@ std::tuple<double, std::error_code> GainMatrixUpdater::visitMeasurementImpl(
8589

8690
// Ensure thet the compiler does not implicitly instantiate the template
8791

88-
#define _EXTERN(N) \
89-
extern template std::tuple<double, std::error_code> \
90-
GainMatrixUpdater::visitMeasurementImpl<N>(InternalTrackState trackState, \
91-
const Logger& logger) const
92+
#define _EXTERN(N) \
93+
extern template std::tuple<double, std::error_code> \
94+
GainMatrixUpdater::visitMeasurementImpl<N>( \
95+
AnyMutableTrackStateProxy trackState, const Logger& logger) const
9296

9397
_EXTERN(1);
9498
_EXTERN(2);

0 commit comments

Comments
 (0)