Skip to content

Commit f5ca41d

Browse files
committed
Refined hole counting in KF
1 parent 83b1b6a commit f5ca41d

File tree

1 file changed

+123
-25
lines changed

1 file changed

+123
-25
lines changed

core/include/traccc/fitting/kalman_filter/kalman_actor.hpp

Lines changed: 123 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@
1919
#include "traccc/utils/particle.hpp"
2020

2121
// detray include(s).
22+
#include <detray/navigation/direct_navigator.hpp>
23+
#include <detray/navigation/navigator.hpp>
2224
#include <detray/propagator/base_actor.hpp>
2325

2426
// vecmem include(s)
2527
#include <vecmem/containers/device_vector.hpp>
2628

29+
// System include(s)
30+
#include <cstdlib>
31+
2732
namespace traccc {
2833

2934
enum class kalman_actor_direction {
@@ -65,6 +70,7 @@ struct kalman_actor_state {
6570
} else {
6671
m_idx = m_track.state_indices().size() - 1;
6772
}
73+
n_holes = 0u;
6874
}
6975

7076
/// Advance the iterator
@@ -80,11 +86,106 @@ struct kalman_actor_state {
8086
/// @return true if the iterator reaches the end of vector
8187
TRACCC_HOST_DEVICE
8288
bool is_complete() {
83-
if (!backward_mode && m_idx == m_track.state_indices().size()) {
84-
return true;
85-
} else if (backward_mode && m_idx > m_track.state_indices().size()) {
86-
return true;
89+
return ((!backward_mode && m_idx == m_track.state_indices().size()) ||
90+
(backward_mode && m_idx > m_track.state_indices().size()));
91+
}
92+
93+
/// @returns the current number of holes in this state
94+
TRACCC_HOST_DEVICE
95+
unsigned int count_missed() const {
96+
unsigned int n_missed{0u};
97+
98+
if (backward_mode) {
99+
for (const auto& trk_state : m_track_states) {
100+
if (trk_state.smoothed().is_invalid()) {
101+
++n_missed;
102+
}
103+
}
104+
} else {
105+
for (const auto& trk_state : m_track_states) {
106+
if (trk_state.filtered().is_invalid()) {
107+
++n_missed;
108+
}
109+
}
87110
}
111+
112+
return n_missed;
113+
}
114+
115+
template <typename nav_state_t>
116+
void evaluate_hole(const nav_state_t& navigation) {
117+
if (navigation.is_good_candidate()) {
118+
++n_holes;
119+
}
120+
}
121+
122+
/// @return true if the iterator reaches the end of vector
123+
/// @TODO: Remove once direct navigator is used in forward pass
124+
template <typename propagation_state_t>
125+
TRACCC_HOST_DEVICE bool check_matching_surface(
126+
propagation_state_t& propagation) {
127+
128+
auto& navigation = propagation._navigation;
129+
auto& trk_state = (*this)();
130+
131+
// Surface was found, continue with KF algorithm
132+
if (navigation.barcode() == trk_state.surface_link()) {
133+
// Count a hole, if track finding did not find a measurement
134+
if (trk_state.is_hole) {
135+
evaluate_hole(navigation);
136+
}
137+
// If track finding did not find measurement on this surface: skip
138+
return !trk_state.is_hole;
139+
}
140+
141+
// Skipped surfaces: adjust iterator and remove counted hole
142+
// (only relevant if using non-direct navigation, e.g. forward truth
143+
// fitting or different prop. config between CKF asnd KF)
144+
// TODO: Remove again
145+
using detector_t = typename propagation_state_t::detector_type;
146+
using nav_state_t = typename propagation_state_t::navigator_state_type;
147+
if constexpr (!std::same_as<nav_state_t,
148+
typename detray::direct_navigator<
149+
detector_t>::state>) {
150+
int i{1};
151+
if (backward_mode) {
152+
// If we are on the last state and the navigation surface does
153+
// not match, it must be an additional surface
154+
// -> continue navigation until matched
155+
if (m_it_rev + 1 == m_track_states.rend()) {
156+
evaluate_hole(navigation);
157+
return false;
158+
}
159+
// Check if the current navigation surfaces can be found on a
160+
// later track state. That means the current track state was
161+
// skipped by the navigator: Advance the internal iterator
162+
for (auto itr = m_it_rev + 1; itr != m_track_states.rend();
163+
++itr) {
164+
if (itr->surface_link() == navigation.barcode()) {
165+
m_it_rev += i;
166+
return true;
167+
}
168+
++i;
169+
}
170+
} else {
171+
if (m_it + 1 == m_track_states.end()) {
172+
evaluate_hole(navigation);
173+
return false;
174+
}
175+
for (auto itr = m_it + 1; itr != m_track_states.end(); ++itr) {
176+
if (itr->surface_link() == navigation.barcode()) {
177+
m_it += i;
178+
return true;
179+
}
180+
++i;
181+
}
182+
}
183+
}
184+
185+
// Mismatch was not from missed state: Is a hole
186+
++n_holes;
187+
188+
// After additional surface, keep navigating until match is found
88189
return false;
89190
}
90191

@@ -104,8 +205,7 @@ struct kalman_actor_state {
104205
/// Index of the current track state
105206
unsigned int m_idx;
106207

107-
// The number of holes (The number of sensitive surfaces which do not
108-
// have a measurement for the track pattern)
208+
// Count the number of encountered surfaces without measurement
109209
unsigned int n_holes{0u};
110210

111211
// Run back filtering for smoothing, if true
@@ -130,7 +230,6 @@ struct kalman_actor : detray::actor {
130230
auto& stepping = propagation._stepping;
131231
auto& navigation = propagation._navigation;
132232

133-
// If the iterator reaches the end, terminate the propagation
134233
if (actor_state.is_complete()) {
135234
propagation._heartbeat &= navigation.exit();
136235
return;
@@ -142,20 +241,15 @@ struct kalman_actor : detray::actor {
142241
typename edm::track_state_collection<algebra_t>::device::proxy_type
143242
trk_state = actor_state();
144243

145-
// Increase the hole counts if the propagator fails to find the next
146-
// measurement
147-
if (navigation.barcode() !=
148-
actor_state.m_measurements.at(trk_state.measurement_index())
149-
.surface_link) {
150-
if (!actor_state.backward_mode) {
151-
actor_state.n_holes++;
152-
}
153-
return;
154-
}
244+
// Did the navigation switch direction?
245+
actor_state.backward_mode =
246+
navigation.direction() ==
247+
detray::navigation::direction::e_backward;
155248

156-
// This track state is not a hole
157-
if (!actor_state.backward_mode) {
158-
trk_state.set_hole(false);
249+
// Increase the hole count if the propagator stops at an additional
250+
// surface and wait for the next sensitive surface to match
251+
if (!actor_state.check_matching_surface(propagation)) {
252+
return;
159253
}
160254

161255
// Run Kalman Gain Updater
@@ -173,10 +267,10 @@ struct kalman_actor : detray::actor {
173267
// Forward filter
174268
res = gain_matrix_updater<algebra_t>{}(
175269
trk_state, actor_state.m_measurements,
176-
propagation._stepping.bound_params(), is_line);
270+
bound_param, is_line);
177271

178272
// Update the propagation flow
179-
stepping.bound_params() = trk_state.filtered_params();
273+
bound_param = trk_state.filtered_params();
180274
} else {
181275
assert(false);
182276
}
@@ -185,10 +279,15 @@ struct kalman_actor : detray::actor {
185279
kalman_actor_direction::BACKWARD_ONLY ||
186280
direction_e ==
187281
kalman_actor_direction::BIDIRECTIONAL) {
282+
// Forward filter did not find this state: skip
283+
if (trk_state.filtered().is_invalid()) {
284+
actor_state.next();
285+
return;
286+
}
188287
// Backward filter for smoothing
189288
res = two_filters_smoother<algebra_t>{}(
190289
trk_state, actor_state.m_measurements,
191-
propagation._stepping.bound_params(), is_line);
290+
bound_param, is_line);
192291
} else {
193292
assert(false);
194293
}
@@ -205,8 +304,7 @@ struct kalman_actor : detray::actor {
205304
// is changed (This rarely happens when qop is set with a poor seed
206305
// resolution)
207306
propagation.set_particle(detail::correct_particle_hypothesis(
208-
stepping.particle_hypothesis(),
209-
propagation._stepping.bound_params()));
307+
stepping.particle_hypothesis(), bound_param));
210308

211309
// Update iterator
212310
actor_state.next();

0 commit comments

Comments
 (0)