Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 123 additions & 25 deletions core/include/traccc/fitting/kalman_filter/kalman_actor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
#include "traccc/utils/particle.hpp"

// detray include(s).
#include <detray/navigation/direct_navigator.hpp>
#include <detray/navigation/navigator.hpp>
#include <detray/propagator/base_actor.hpp>

// vecmem include(s)
#include <vecmem/containers/device_vector.hpp>

// System include(s)
#include <cstdlib>

namespace traccc {

enum class kalman_actor_direction {
Expand Down Expand Up @@ -65,6 +70,7 @@ struct kalman_actor_state {
} else {
m_idx = m_track.state_indices().size() - 1;
}
n_holes = 0u;
}

/// Advance the iterator
Expand All @@ -80,11 +86,106 @@ struct kalman_actor_state {
/// @return true if the iterator reaches the end of vector
TRACCC_HOST_DEVICE
bool is_complete() {
if (!backward_mode && m_idx == m_track.state_indices().size()) {
return true;
} else if (backward_mode && m_idx > m_track.state_indices().size()) {
return true;
return ((!backward_mode && m_idx == m_track.state_indices().size()) ||
(backward_mode && m_idx > m_track.state_indices().size()));
}

/// @returns the current number of holes in this state
TRACCC_HOST_DEVICE
unsigned int count_missed() const {
unsigned int n_missed{0u};

if (backward_mode) {
for (const auto& trk_state : m_track_states) {
if (trk_state.smoothed().is_invalid()) {
++n_missed;
}
}
} else {
for (const auto& trk_state : m_track_states) {
if (trk_state.filtered().is_invalid()) {
++n_missed;
}
}
}

return n_missed;
}

template <typename nav_state_t>
void evaluate_hole(const nav_state_t& navigation) {
if (navigation.is_good_candidate()) {
++n_holes;
}
}

/// @return true if the iterator reaches the end of vector
/// @TODO: Remove once direct navigator is used in forward pass
template <typename propagation_state_t>
TRACCC_HOST_DEVICE bool check_matching_surface(
propagation_state_t& propagation) {

auto& navigation = propagation._navigation;
auto& trk_state = (*this)();

// Surface was found, continue with KF algorithm
if (navigation.barcode() == trk_state.surface_link()) {
// Count a hole, if track finding did not find a measurement
if (trk_state.is_hole) {
evaluate_hole(navigation);
}
// If track finding did not find measurement on this surface: skip
return !trk_state.is_hole;
}

// Skipped surfaces: adjust iterator and remove counted hole
// (only relevant if using non-direct navigation, e.g. forward truth
// fitting or different prop. config between CKF asnd KF)
// TODO: Remove again
using detector_t = typename propagation_state_t::detector_type;
using nav_state_t = typename propagation_state_t::navigator_state_type;
if constexpr (!std::same_as<nav_state_t,
typename detray::direct_navigator<
detector_t>::state>) {
int i{1};
if (backward_mode) {
// If we are on the last state and the navigation surface does
// not match, it must be an additional surface
// -> continue navigation until matched
if (m_it_rev + 1 == m_track_states.rend()) {
evaluate_hole(navigation);
return false;
}
// Check if the current navigation surfaces can be found on a
// later track state. That means the current track state was
// skipped by the navigator: Advance the internal iterator
for (auto itr = m_it_rev + 1; itr != m_track_states.rend();
++itr) {
if (itr->surface_link() == navigation.barcode()) {
m_it_rev += i;
return true;
}
++i;
}
} else {
if (m_it + 1 == m_track_states.end()) {
evaluate_hole(navigation);
return false;
}
for (auto itr = m_it + 1; itr != m_track_states.end(); ++itr) {
if (itr->surface_link() == navigation.barcode()) {
m_it += i;
return true;
}
++i;
}
}
}

// Mismatch was not from missed state: Is a hole
++n_holes;

// After additional surface, keep navigating until match is found
return false;
}

Expand All @@ -104,8 +205,7 @@ struct kalman_actor_state {
/// Index of the current track state
unsigned int m_idx;

// The number of holes (The number of sensitive surfaces which do not
// have a measurement for the track pattern)
// Count the number of encountered surfaces without measurement
unsigned int n_holes{0u};

// Run back filtering for smoothing, if true
Expand All @@ -130,7 +230,6 @@ struct kalman_actor : detray::actor {
auto& stepping = propagation._stepping;
auto& navigation = propagation._navigation;

// If the iterator reaches the end, terminate the propagation
if (actor_state.is_complete()) {
propagation._heartbeat &= navigation.exit();
return;
Expand All @@ -142,20 +241,15 @@ struct kalman_actor : detray::actor {
typename edm::track_state_collection<algebra_t>::device::proxy_type
trk_state = actor_state();

// Increase the hole counts if the propagator fails to find the next
// measurement
if (navigation.barcode() !=
actor_state.m_measurements.at(trk_state.measurement_index())
.surface_link) {
if (!actor_state.backward_mode) {
actor_state.n_holes++;
}
return;
}
// Did the navigation switch direction?
actor_state.backward_mode =
navigation.direction() ==
detray::navigation::direction::e_backward;

// This track state is not a hole
if (!actor_state.backward_mode) {
trk_state.set_hole(false);
// Increase the hole count if the propagator stops at an additional
// surface and wait for the next sensitive surface to match
if (!actor_state.check_matching_surface(propagation)) {
return;
}

// Run Kalman Gain Updater
Expand All @@ -173,10 +267,10 @@ struct kalman_actor : detray::actor {
// Forward filter
res = gain_matrix_updater<algebra_t>{}(
trk_state, actor_state.m_measurements,
propagation._stepping.bound_params(), is_line);
bound_param, is_line);

// Update the propagation flow
stepping.bound_params() = trk_state.filtered_params();
bound_param = trk_state.filtered_params();
} else {
assert(false);
}
Expand All @@ -185,10 +279,15 @@ struct kalman_actor : detray::actor {
kalman_actor_direction::BACKWARD_ONLY ||
direction_e ==
kalman_actor_direction::BIDIRECTIONAL) {
// Forward filter did not find this state: skip
if (trk_state.filtered().is_invalid()) {
actor_state.next();
return;
}
// Backward filter for smoothing
res = two_filters_smoother<algebra_t>{}(
trk_state, actor_state.m_measurements,
propagation._stepping.bound_params(), is_line);
bound_param, is_line);
} else {
assert(false);
}
Expand All @@ -205,8 +304,7 @@ struct kalman_actor : detray::actor {
// is changed (This rarely happens when qop is set with a poor seed
// resolution)
propagation.set_particle(detail::correct_particle_hypothesis(
stepping.particle_hypothesis(),
propagation._stepping.bound_params()));
stepping.particle_hypothesis(), bound_param));

// Update iterator
actor_state.next();
Expand Down
Loading