Skip to content

Commit 819e4d3

Browse files
committed
update to detray v0.104.0
1 parent b1557b1 commit 819e4d3

File tree

20 files changed

+156
-33
lines changed

20 files changed

+156
-33
lines changed

core/include/traccc/finding/actors/ckf_aborter.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct ckf_aborter : detray::actor {
3838
propagator_state_t &prop_state) const {
3939

4040
auto &navigation = prop_state._navigation;
41-
auto &stepping = prop_state._stepping;
41+
const auto &stepping = prop_state._stepping;
4242

4343
abrt_state.count++;
4444
abrt_state.path_from_surface += stepping.step_size();
@@ -47,7 +47,7 @@ struct ckf_aborter : detray::actor {
4747
if (navigation.is_on_sensitive() &&
4848
abrt_state.path_from_surface > abrt_state.min_step_length) {
4949
prop_state._heartbeat &= navigation.pause();
50-
abrt_state.success = true;
50+
abrt_state.success = navigation.is_alive();
5151
}
5252

5353
// Reset path from surface
@@ -60,6 +60,7 @@ struct ckf_aborter : detray::actor {
6060
prop_state._heartbeat &= navigation.abort(
6161
"CKF: Maximum number of steps to reach next sensitive surface "
6262
"exceeded");
63+
abrt_state.success = false;
6364
}
6465
}
6566
};

core/include/traccc/finding/candidate_link.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ struct candidate_link {
2828
// Index to the initial seed
2929
unsigned int seed_idx;
3030

31+
// How many candidates were added to the track
32+
unsigned int n_cand;
33+
3134
// How many times it skipped a surface
3235
unsigned int n_skipped;
3336

core/include/traccc/finding/details/combinatorial_kalman_filter.hpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ combinatorial_kalman_filter(
139139

140140
std::vector<bound_track_parameters<algebra_type>> out_params;
141141

142+
std::vector<std::uint8_t> in_is_edge(seeds.size(), false);
143+
std::vector<std::uint8_t> out_is_edge;
144+
142145
for (unsigned int step = 0u; step < config.max_track_candidates_per_track;
143146
step++) {
144147

@@ -156,6 +159,7 @@ combinatorial_kalman_filter(
156159

157160
// Rough estimation on out parameters size
158161
out_params.reserve(n_in_params);
162+
out_is_edge.reserve(n_in_params);
159163

160164
// Previous step ID
161165
std::fill(n_trks_per_seed.begin(), n_trks_per_seed.end(), 0u);
@@ -169,13 +173,20 @@ combinatorial_kalman_filter(
169173
bound_track_parameters<algebra_type>& in_param =
170174
in_params[in_param_id];
171175

176+
const bool is_edge{in_is_edge[in_param_id] > 0u};
177+
172178
assert(!in_param.is_invalid());
173179

174180
const unsigned int orig_param_id =
175181
(step == 0
176182
? in_param_id
177183
: links[step - 1][param_to_link[step - 1][in_param_id]]
178184
.seed_idx);
185+
const unsigned int n_cand =
186+
(step == 0
187+
? 0
188+
: links[step - 1][param_to_link[step - 1][in_param_id]]
189+
.n_cand);
179190
const unsigned int skip_counter =
180191
(step == 0
181192
? 0
@@ -261,6 +272,7 @@ combinatorial_kalman_filter(
261272
.previous_candidate_idx = in_param_id,
262273
.meas_idx = meas_id,
263274
.seed_idx = orig_param_id,
275+
.n_cand = n_cand + 1,
264276
.n_skipped = skip_counter,
265277
.chi2 = chi2,
266278
.chi2_sum = prev_chi2_sum + chi2,
@@ -306,7 +318,8 @@ combinatorial_kalman_filter(
306318
.previous_candidate_idx = in_param_id,
307319
.meas_idx = std::numeric_limits<unsigned int>::max(),
308320
.seed_idx = orig_param_id,
309-
.n_skipped = skip_counter + 1,
321+
.n_cand = n_cand,
322+
.n_skipped = is_edge ? skip_counter : skip_counter + 1,
310323
.chi2 = std::numeric_limits<traccc::scalar>::max(),
311324
.chi2_sum = prev_chi2_sum,
312325
.ndf_sum = prev_ndf_sum});
@@ -356,7 +369,7 @@ combinatorial_kalman_filter(
356369
prob(Lthisbase.chi2_sum,
357370
static_cast<scalar>(Lthisbase.ndf_sum - 5));
358371

359-
if (step + 1 - Lthisbase.n_skipped <=
372+
if (Lthisbase.n_cand <=
360373
config.duplicate_removal_minimum_length ||
361374
Lthisbase.ndf_sum <= 5) {
362375
continue;
@@ -370,7 +383,7 @@ combinatorial_kalman_filter(
370383
auto Lthis = Lthisbase;
371384
auto Lthat = links.at(step).at(tracks.at(j));
372385

373-
if (step + 1 - Lthat.n_skipped <=
386+
if (Lthat.n_cand <=
374387
config.duplicate_removal_minimum_length ||
375388
Lthisbase.ndf_sum <= 5) {
376389
continue;
@@ -495,6 +508,8 @@ combinatorial_kalman_filter(
495508
assert(!propagation._stepping.bound_params().is_invalid());
496509

497510
out_params.push_back(propagation._stepping.bound_params());
511+
out_is_edge.push_back(
512+
propagation._navigation.is_edge_candidate());
498513
param_to_link[step].push_back(link_id);
499514
}
500515
// Unless the track found a surface, it is considered a
@@ -513,7 +528,9 @@ combinatorial_kalman_filter(
513528
}
514529

515530
in_params = std::move(out_params);
531+
in_is_edge = std::move(out_is_edge);
516532
out_params.clear();
533+
out_is_edge.clear();
517534
}
518535

519536
/**********************
@@ -529,7 +546,7 @@ combinatorial_kalman_filter(
529546
// Get the link corresponding to tip
530547
auto L = links.at(tip.first).at(tip.second);
531548

532-
const unsigned int n_cands = tip.first + 1 - L.n_skipped;
549+
const unsigned int n_cands = L.n_cand;
533550

534551
// Skip if the number of tracks candidates is too small
535552
if (n_cands < config.min_track_candidates_per_track ||

core/include/traccc/fitting/details/kalman_fitting.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ typename edm::track_fit_container<algebra_t>::host kalman_fitting(
9191
result_tracks_device.at(result_tracks_device.size() - 1),
9292
typename edm::track_state_collection<algebra_t>::device{
9393
vecmem::get_data(result.states)},
94-
measurements, seqs_buffer);
94+
measurements, seqs_buffer, fitter.config().propagation);
9595

9696
// Run the fitter. The status that it returns is not used here. The main
9797
// failure modes are saved onto the fitted track itself. Not sure what

core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,28 +110,31 @@ class kalman_fitter {
110110
track_states,
111111
const measurement_collection_types::const_device& measurements,
112112
vecmem::data::vector_view<detray::geometry::barcode>
113-
sequence_buffer)
113+
sequence_buffer,
114+
const detray::propagation::config& prop_cfg)
114115
: m_fit_actor_state{track, track_states, measurements},
115116
m_sequencer_state(
116117
vecmem::device_vector<detray::geometry::barcode>(
117118
sequence_buffer)),
119+
m_parameter_resetter{prop_cfg},
118120
m_fit_res{track},
119121
m_sequence_buffer(sequence_buffer) {}
120122

121123
/// @return the actor chain state
122124
TRACCC_HOST_DEVICE
123125
typename forward_actor_chain_type::state_ref_tuple operator()() {
124126
return detray::tie(m_aborter_state, m_interactor_state,
125-
m_fit_actor_state, m_sequencer_state,
126-
m_step_aborter_state);
127+
m_fit_actor_state, m_parameter_resetter,
128+
m_sequencer_state, m_step_aborter_state);
127129
}
128130

129131
/// @return the actor chain state
130132
TRACCC_HOST_DEVICE
131133
typename backward_actor_chain_type::state_ref_tuple
132134
backward_actor_state() {
133135
return detray::tie(m_aborter_state, m_fit_actor_state,
134-
m_interactor_state, m_step_aborter_state);
136+
m_interactor_state, m_parameter_resetter,
137+
m_step_aborter_state);
135138
}
136139

137140
/// Individual actor states
@@ -140,6 +143,7 @@ class kalman_fitter {
140143
typename forward_fit_actor::state m_fit_actor_state;
141144
typename barcode_sequencer::state m_sequencer_state;
142145
kalman_step_aborter::state m_step_aborter_state{};
146+
typename resetter::state m_parameter_resetter{};
143147

144148
/// Fitting result per track
145149
typename edm::track_fit_collection<algebra_type>::device::proxy_type

device/alpaka/src/finding/combinatorial_kalman_filter.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ combinatorial_kalman_filter(
236236
copy.setup(param_liveness_buffer)->wait();
237237
copy.memset(param_liveness_buffer, 1)->wait();
238238

239+
// On the first measurement, no parameter can be a hole
240+
vecmem::data::vector_buffer<std::uint8_t> is_edges_buffer(n_seeds, mr.main);
241+
copy.setup(is_edges_buffer)->wait();
242+
copy.memset(is_edges_buffer, 0)->wait();
243+
239244
// Number of tracks per seed
240245
vecmem::data::vector_buffer<unsigned int> n_tracks_per_seed_buffer(n_seeds,
241246
mr.main);
@@ -300,6 +305,11 @@ combinatorial_kalman_filter(
300305
n_max_candidates, mr.main);
301306
copy.setup(updated_liveness_buffer)->wait();
302307

308+
// Updated edges buffer after branching
309+
vecmem::data::vector_buffer<std::uint8_t> is_edges_updated_buffer(
310+
n_max_candidates, mr.main);
311+
copy.setup(is_edges_updated_buffer)->wait();
312+
303313
// Reset the number of tracks per seed
304314
copy.memset(n_tracks_per_seed_buffer, 0)->wait();
305315

@@ -343,6 +353,7 @@ combinatorial_kalman_filter(
343353
.in_params_view = in_params_buffer,
344354
.in_params_liveness_view = param_liveness_buffer,
345355
.n_in_params = n_in_params,
356+
.is_edges_view = is_edges_buffer,
346357
.measurement_ranges_view = meas_ranges_buffer,
347358
.links_view = links_buffer,
348359
.prev_links_idx =
@@ -506,6 +517,7 @@ combinatorial_kalman_filter(
506517
.params_view = in_params_buffer,
507518
.params_liveness_view = param_liveness_buffer,
508519
.param_ids_view = param_ids_buffer,
520+
.is_edges_view = is_edges_updated_buffer,
509521
.links_view = links_buffer,
510522
.prev_links_idx = step_to_link_idx_map[step],
511523
.step = step,
@@ -536,6 +548,8 @@ combinatorial_kalman_filter(
536548
bfield_t>{},
537549
config, device_payload.ptr());
538550
::alpaka::wait(queue);
551+
552+
std::swap(is_edges_buffer, is_edges_updated_buffer);
539553
}
540554
}
541555

device/common/include/traccc/finding/device/find_tracks.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ struct find_tracks_payload {
5959
*/
6060
unsigned int n_in_params;
6161

62+
/**
63+
* @brief View object to the vector of booleans that flags if the sensor
64+
* was hit due to tolerance inflation
65+
*/
66+
vecmem::data::vector_view<std::uint8_t> is_edges_view;
67+
6268
/**
6369
* @brief View object to the vector of measurement index ranges per surface
6470
*/

device/common/include/traccc/finding/device/impl/find_tracks.ipp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
5757
payload.in_params_view);
5858
vecmem::device_vector<const unsigned int> in_params_liveness(
5959
payload.in_params_liveness_view);
60+
vecmem::device_vector<const std::uint8_t> is_edges(payload.is_edges_view);
6061
vecmem::device_vector<candidate_link> links(payload.links_view);
6162
vecmem::device_vector<candidate_link> tmp_links(payload.tmp_links_view);
6263
bound_track_parameters_collection_types::device tmp_params(
@@ -406,6 +407,8 @@ TRACCC_HOST_DEVICE inline void find_tracks(
406407
* Now, simply insert the temporary link at the found
407408
* position. Different cases for step 0 and other steps.
408409
*/
410+
const unsigned int n_cand =
411+
payload.step == 0 ? 0 : links.at(prev_link_idx).n_cand;
409412
const unsigned int n_skipped =
410413
payload.step == 0 ? 0
411414
: links.at(prev_link_idx).n_skipped;
@@ -423,6 +426,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
423426
.previous_candidate_idx = prev_link_idx,
424427
.meas_idx = meas_idx,
425428
.seed_idx = seed_idx,
429+
.n_cand = n_cand + 1,
426430
.n_skipped = n_skipped,
427431
.chi2 = chi2,
428432
.chi2_sum = prev_chi2_sum + chi2,
@@ -493,6 +497,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
493497

494498
unsigned int prev_link_idx = std::numeric_limits<unsigned int>::max();
495499
unsigned int seed_idx = std::numeric_limits<unsigned int>::max();
500+
unsigned int n_cand = 0u;
496501
unsigned int n_skipped = std::numeric_limits<unsigned int>::max();
497502
unsigned int prev_ndf_sum = 0u;
498503
scalar prev_chi2_sum = 0.f;
@@ -510,6 +515,7 @@ TRACCC_HOST_DEVICE inline void find_tracks(
510515
prev_link_idx = payload.prev_links_idx + in_param_id;
511516
seed_idx =
512517
payload.step > 0 ? links.at(prev_link_idx).seed_idx : in_param_id;
518+
n_cand = payload.step == 0 ? 0 : links.at(prev_link_idx).n_cand;
513519
n_skipped = payload.step == 0 ? 0 : links.at(prev_link_idx).n_skipped;
514520
in_param_can_create_hole =
515521
(n_skipped < cfg.max_num_skipping_per_cand) && (!last_step);
@@ -581,12 +587,14 @@ TRACCC_HOST_DEVICE inline void find_tracks(
581587
const unsigned int out_offset =
582588
shared_payload.shared_out_offset + local_out_offset;
583589

590+
const bool is_edge{is_edges.at(in_param_id) > 0u};
584591
links.at(out_offset) = {
585592
.step = payload.step,
586593
.previous_candidate_idx = prev_link_idx,
587594
.meas_idx = std::numeric_limits<unsigned int>::max(),
588595
.seed_idx = seed_idx,
589-
.n_skipped = n_skipped + 1,
596+
.n_cand = n_cand,
597+
.n_skipped = is_edge ? n_skipped : n_skipped + 1,
590598
.chi2 = std::numeric_limits<traccc::scalar>::max(),
591599
.chi2_sum = prev_chi2_sum,
592600
.ndf_sum = prev_ndf_sum};
@@ -597,11 +605,9 @@ TRACCC_HOST_DEVICE inline void find_tracks(
597605
out_params_liveness.at(param_pos) =
598606
static_cast<unsigned int>(!last_step);
599607
} else {
600-
const unsigned int n_cands = payload.step - n_skipped;
601-
602-
if (n_cands >= cfg.min_track_candidates_per_track) {
608+
if (n_cand >= cfg.min_track_candidates_per_track) {
603609
auto tip_pos = tips.push_back(prev_link_idx);
604-
tip_lengths.at(tip_pos) = n_cands;
610+
tip_lengths.at(tip_pos) = n_cand;
605611
}
606612
}
607613
} else {
@@ -620,12 +626,9 @@ TRACCC_HOST_DEVICE inline void find_tracks(
620626
static_cast<unsigned int>(!last_step);
621627
links.at(out_offset) = tmp_links.at(in_offset);
622628

623-
const unsigned int n_cands = payload.step + 1 - n_skipped;
624-
625-
if (last_step &&
626-
n_cands >= cfg.min_track_candidates_per_track) {
629+
if (last_step && n_cand >= cfg.min_track_candidates_per_track) {
627630
auto tip_pos = tips.push_back(out_offset);
628-
tip_lengths.at(tip_pos) = n_cands;
631+
tip_lengths.at(tip_pos) = n_cand;
629632
}
630633
}
631634
}

device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,15 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
3838
const unsigned int link_idx = payload.prev_links_idx + param_id;
3939
const auto& link = links.at(link_idx);
4040
assert(link.step == payload.step);
41-
const unsigned int n_cands = link.step + 1 - link.n_skipped;
41+
const unsigned int n_cands = link.n_cand;
4242

4343
// Parameter liveness
4444
vecmem::device_vector<unsigned int> params_liveness(
4545
payload.params_liveness_view);
4646

47+
// Whether the surface was hit in the tolerance band
48+
vecmem::device_vector<std::uint8_t> is_edges(payload.is_edges_view);
49+
4750
// tips
4851
vecmem::device_vector<unsigned int> tips(payload.tips_view);
4952
vecmem::device_vector<unsigned int> tip_lengths(payload.tip_lengths_view);
@@ -101,6 +104,7 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
101104

102105
params[param_id] = propagation._stepping.bound_params();
103106
params_liveness[param_id] = 1u;
107+
is_edges[param_id] = propagation._navigation.is_edge_candidate();
104108
} else {
105109
params_liveness[param_id] = 0u;
106110

device/common/include/traccc/finding/device/propagate_to_next_surface.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ struct propagate_to_next_surface_payload {
5050
*/
5151
vecmem::data::vector_view<const unsigned int> param_ids_view;
5252

53+
/**
54+
* @brief View object to the vector of booleans that flags if the sensor
55+
* was hit due to tolerance inflation
56+
*/
57+
vecmem::data::vector_view<std::uint8_t> is_edges_view;
58+
5359
/**
5460
* @brief View object to the vector of candidate links
5561
*/

0 commit comments

Comments
 (0)