Skip to content

Commit 83b1b6a

Browse files
authored
Merge pull request #1145 from stephenswat/refactor/kalman_actor_state
Simplify iterators in Kálmán actor
2 parents b3e82ce + c10eb69 commit 83b1b6a

File tree

2 files changed

+30
-29
lines changed

2 files changed

+30
-29
lines changed

core/include/traccc/fitting/kalman_filter/kalman_actor.hpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,54 +54,55 @@ struct kalman_actor_state {
5454
TRACCC_HOST_DEVICE
5555
typename edm::track_state_collection<algebra_t>::device::proxy_type
5656
operator()() {
57-
if (!backward_mode) {
58-
return m_track_states.at(*m_it);
59-
} else {
60-
return m_track_states.at(*m_it_rev);
61-
}
57+
return m_track_states.at(m_track.state_indices().at(m_idx));
6258
}
6359

6460
/// Reset the iterator
6561
TRACCC_HOST_DEVICE
6662
void reset() {
67-
m_it = m_track.state_indices().begin();
68-
m_it_rev = m_track.state_indices().rbegin();
63+
if (!backward_mode) {
64+
m_idx = 0;
65+
} else {
66+
m_idx = m_track.state_indices().size() - 1;
67+
}
6968
}
7069

7170
/// Advance the iterator
7271
TRACCC_HOST_DEVICE
7372
void next() {
7473
if (!backward_mode) {
75-
m_it++;
74+
m_idx++;
7675
} else {
77-
m_it_rev++;
76+
m_idx--;
7877
}
7978
}
8079

8180
/// @return true if the iterator reaches the end of vector
8281
TRACCC_HOST_DEVICE
8382
bool is_complete() {
84-
if (!backward_mode && m_it == m_track.state_indices().end()) {
83+
if (!backward_mode && m_idx == m_track.state_indices().size()) {
8584
return true;
86-
} else if (backward_mode &&
87-
m_it_rev == m_track.state_indices().rend()) {
85+
} else if (backward_mode && m_idx > m_track.state_indices().size()) {
8886
return true;
8987
}
9088
return false;
9189
}
9290

91+
TRACCC_HOST_DEVICE
92+
bool is_state() {
93+
return m_track.state_indices().at(m_idx) !=
94+
std::numeric_limits<unsigned int>::max();
95+
}
96+
9397
/// Object describing the track fit
9498
typename edm::track_fit_collection<algebra_t>::device::proxy_type m_track;
9599
/// All track states in the event
96100
typename edm::track_state_collection<algebra_t>::device m_track_states;
97101
/// All measurements in the event
98102
measurement_collection_types::const_device m_measurements;
99103

100-
/// Iterator for forward filtering over the track states
101-
vecmem::device_vector<unsigned int>::iterator m_it;
102-
103-
/// Iterator for backward filtering over the track states
104-
vecmem::device_vector<unsigned int>::reverse_iterator m_it_rev;
104+
/// Index of the current track state
105+
unsigned int m_idx;
105106

106107
// The number of holes (The number of sensitive surfaces which do not
107108
// have a measurement for the track pattern)

core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,22 +259,23 @@ class kalman_fitter {
259259
return kalman_fitter_status::ERROR_BARCODE_SEQUENCE_OVERFLOW;
260260
}
261261

262-
auto& track = fitter_state.m_fit_actor_state.m_track;
263-
auto& track_states = fitter_state.m_fit_actor_state.m_track_states;
262+
fitter_state.m_fit_actor_state.backward_mode = true;
263+
fitter_state.m_fit_actor_state.reset();
264264

265265
// Since the smoothed track parameter of the last surface can be
266266
// considered to be the filtered one, we can reversly iterate the
267267
// algorithm to obtain the smoothed parameter of other surfaces
268-
for (auto it = track.state_indices().rbegin();
269-
it != track.state_indices().rend(); ++it) {
270-
if (!track_states.at(*it).is_hole()) {
271-
fitter_state.m_fit_actor_state.m_it_rev = it;
272-
break;
273-
}
274-
// TODO: Return false because there is no valid track state
275-
// return false;
268+
while (!fitter_state.m_fit_actor_state.is_complete() &&
269+
(!fitter_state.m_fit_actor_state.is_state() ||
270+
fitter_state.m_fit_actor_state().is_hole())) {
271+
fitter_state.m_fit_actor_state.next();
276272
}
277-
auto last = track_states.at(*fitter_state.m_fit_actor_state.m_it_rev);
273+
274+
if (fitter_state.m_fit_actor_state.is_complete()) {
275+
return kalman_fitter_status::SUCCESS;
276+
}
277+
278+
auto last = fitter_state.m_fit_actor_state();
278279

279280
const scalar theta = last.filtered_params().theta();
280281
if (theta <= 0.f || theta >= constant<traccc::scalar>::pi) {
@@ -322,7 +323,6 @@ class kalman_fitter {
322323

323324
propagation._navigation.set_direction(
324325
detray::navigation::direction::e_backward);
325-
fitter_state.m_fit_actor_state.backward_mode = true;
326326

327327
// Synchronize the current barcode with the input track parameter
328328
while (propagation._navigation.get_target_barcode() !=

0 commit comments

Comments
 (0)