Skip to content

Commit 71ef2b5

Browse files
authored
CKF Track Deduplication, main branch (2025.07.15.) (#1078)
* Removed CUDA formalism from common code. * Implemented track deduplication for SYCL. * Implemented track deduplication for Alpaka. * Allow SYCL CKF test to use track deduplication.
1 parent df5cf48 commit 71ef2b5

File tree

4 files changed

+212
-6
lines changed

4 files changed

+212
-6
lines changed

device/alpaka/src/finding/combinatorial_kalman_filter.hpp

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
#include "traccc/finding/details/combinatorial_kalman_filter_types.hpp"
2222
#include "traccc/finding/device/apply_interaction.hpp"
2323
#include "traccc/finding/device/build_tracks.hpp"
24+
#include "traccc/finding/device/fill_finding_duplicate_removal_sort_keys.hpp"
2425
#include "traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
2526
#include "traccc/finding/device/find_tracks.hpp"
2627
#include "traccc/finding/device/make_barcode_sequence.hpp"
2728
#include "traccc/finding/device/propagate_to_next_surface.hpp"
29+
#include "traccc/finding/device/remove_duplicates.hpp"
2830
#include "traccc/finding/finding_config.hpp"
2931
#include "traccc/utils/logging.hpp"
3032
#include "traccc/utils/memory_resource.hpp"
@@ -105,8 +107,37 @@ struct find_tracks {
105107
}
106108
};
107109

108-
/// Alpaka kernel functor for @c
109-
/// traccc::device::fill_finding_propagation_sort_keys
110+
/// Alpaka kernel functor for
111+
/// @c traccc::device::fill_finding_duplicate_removal_sort_keys
112+
struct fill_finding_duplicate_removal_sort_keys {
113+
template <typename TAcc>
114+
ALPAKA_FN_ACC void operator()(
115+
TAcc const& acc,
116+
const device::fill_finding_duplicate_removal_sort_keys_payload& payload)
117+
const {
118+
119+
const device::global_index_t globalThreadIdx =
120+
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
121+
device::fill_finding_duplicate_removal_sort_keys(globalThreadIdx,
122+
payload);
123+
}
124+
};
125+
126+
/// Alpaka kernel functor for @c traccc::device::remove_duplicates
127+
struct remove_duplicates {
128+
template <typename TAcc>
129+
ALPAKA_FN_ACC void operator()(
130+
TAcc const& acc, const finding_config& cfg,
131+
const device::remove_duplicates_payload& payload) const {
132+
133+
const device::global_index_t globalThreadIdx =
134+
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
135+
device::remove_duplicates(globalThreadIdx, cfg, payload);
136+
}
137+
};
138+
139+
/// Alpaka kernel functor for
140+
/// @c traccc::device::fill_finding_propagation_sort_keys
110141
struct fill_finding_propagation_sort_keys {
111142
template <typename TAcc>
112143
ALPAKA_FN_ACC void operator()(
@@ -408,6 +439,74 @@ combinatorial_kalman_filter(
408439
step_to_link_idx_map[step + 1] - step_to_link_idx_map[step];
409440
}
410441

442+
/*
443+
* On later steps, we can duplicate removal which will attempt to find
444+
* tracks that are propagated multiple times and deduplicate them.
445+
*/
446+
if (n_candidates > 0 &&
447+
step >= config.duplicate_removal_minimum_length) {
448+
vecmem::data::vector_buffer<unsigned int>
449+
link_last_measurement_buffer(n_candidates, mr.main);
450+
vecmem::data::vector_buffer<unsigned int> param_ids_buffer(
451+
n_candidates, mr.main);
452+
453+
/*
454+
* First, we sort the tracks by the index of their final
455+
* measurement which is critical to ensure good performance.
456+
*/
457+
{
458+
const unsigned int nThreads = 256;
459+
const unsigned int nBlocks =
460+
(n_candidates + nThreads - 1) / nThreads;
461+
const auto workDiv = makeWorkDiv<Acc>(nBlocks, nThreads);
462+
463+
::alpaka::exec<Acc>(
464+
queue, workDiv,
465+
kernels::fill_finding_duplicate_removal_sort_keys{},
466+
device::fill_finding_duplicate_removal_sort_keys_payload{
467+
.links_view = links_buffer,
468+
.param_liveness_view = param_liveness_buffer,
469+
.link_last_measurement_view =
470+
link_last_measurement_buffer,
471+
.param_ids_view = param_ids_buffer,
472+
.n_links = n_candidates,
473+
.curr_links_idx = step_to_link_idx_map[step],
474+
.n_measurements = n_measurements});
475+
::alpaka::wait(queue);
476+
}
477+
478+
vecmem::device_vector<unsigned int> keys_device(
479+
link_last_measurement_buffer);
480+
vecmem::device_vector<unsigned int> param_ids_device(
481+
param_ids_buffer);
482+
thrust::sort_by_key(thrustExecPolicy, keys_device.begin(),
483+
keys_device.end(), param_ids_device.begin());
484+
485+
/*
486+
* Then, we run the actual duplicate removal kernel.
487+
*/
488+
{
489+
const unsigned int nThreads = 256;
490+
const unsigned int nBlocks =
491+
(n_candidates + nThreads - 1) / nThreads;
492+
const auto workDiv = makeWorkDiv<Acc>(nBlocks, nThreads);
493+
494+
::alpaka::exec<Acc>(
495+
queue, workDiv, kernels::remove_duplicates{}, config,
496+
device::remove_duplicates_payload{
497+
.links_view = links_buffer,
498+
.link_last_measurement_view =
499+
link_last_measurement_buffer,
500+
.param_ids_view = param_ids_buffer,
501+
.param_liveness_view = param_liveness_buffer,
502+
.n_links = n_candidates,
503+
.curr_links_idx = step_to_link_idx_map[step],
504+
.n_measurements = n_measurements,
505+
.step = step});
506+
::alpaka::wait(queue);
507+
}
508+
}
509+
411510
if (step == config.max_track_candidates_per_track - 1) {
412511
break;
413512
}
@@ -566,3 +665,22 @@ struct BlockSharedMemDynSizeBytes<
566665
};
567666

568667
} // namespace alpaka::trait
668+
669+
namespace alpaka {
670+
671+
/// Convince Alpaka that
672+
/// @c traccc::device::fill_finding_duplicate_removal_sort_keys_payload
673+
/// is trivially copyable
674+
template <>
675+
struct IsKernelArgumentTriviallyCopyable<
676+
traccc::device::fill_finding_duplicate_removal_sort_keys_payload, void>
677+
: std::true_type {};
678+
679+
/// Convince Alpaka that
680+
/// @c traccc::device::remove_duplicates_payload
681+
/// is trivially copyable
682+
template <>
683+
struct IsKernelArgumentTriviallyCopyable<
684+
traccc::device::remove_duplicates_payload, void> : std::true_type {};
685+
686+
} // namespace alpaka

device/common/include/traccc/finding/device/remove_duplicates.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ struct remove_duplicates_payload {
9494
* increasingly less likely that they will diverge afterwards.
9595
*/
9696
TRACCC_HOST_DEVICE inline void remove_duplicates(
97-
global_index_t gid, const finding_config& cfg,
97+
global_index_t tid, const finding_config& cfg,
9898
const remove_duplicates_payload& payload) {
9999

100100
const vecmem::device_vector<const candidate_link> links(payload.links_view);
@@ -105,8 +105,6 @@ TRACCC_HOST_DEVICE inline void remove_duplicates(
105105
const vecmem::device_vector<const unsigned int> param_ids(
106106
payload.param_ids_view);
107107

108-
const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
109-
110108
/*
111109
* As is standard fare, we ignore tracks that are out of bounds or that
112110
* have already been marked as "dead". Since this kernel contains no

device/sycl/src/finding/combinatorial_kalman_filter.hpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
#include "traccc/finding/details/combinatorial_kalman_filter_types.hpp"
2525
#include "traccc/finding/device/apply_interaction.hpp"
2626
#include "traccc/finding/device/build_tracks.hpp"
27+
#include "traccc/finding/device/fill_finding_duplicate_removal_sort_keys.hpp"
2728
#include "traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
2829
#include "traccc/finding/device/find_tracks.hpp"
2930
#include "traccc/finding/device/make_barcode_sequence.hpp"
3031
#include "traccc/finding/device/propagate_to_next_surface.hpp"
32+
#include "traccc/finding/device/remove_duplicates.hpp"
3133
#include "traccc/finding/finding_config.hpp"
3234
#include "traccc/utils/memory_resource.hpp"
3335
#include "traccc/utils/projections.hpp"
@@ -49,6 +51,10 @@ struct apply_interaction {};
4951
template <typename T>
5052
struct find_tracks {};
5153
template <typename T>
54+
struct fill_finding_duplicate_removal_sort_keys {};
55+
template <typename T>
56+
struct remove_duplicates {};
57+
template <typename T>
5258
struct fill_finding_propagation_sort_keys {};
5359
template <typename T>
5460
struct propagate_to_next_surface {};
@@ -334,6 +340,91 @@ combinatorial_kalman_filter(
334340
step_to_link_idx_map[step + 1] - step_to_link_idx_map[step];
335341
}
336342

343+
/*
344+
* On later steps, we can duplicate removal which will attempt to find
345+
* tracks that are propagated multiple times and deduplicate them.
346+
*/
347+
if (n_candidates > 0 &&
348+
step >= config.duplicate_removal_minimum_length) {
349+
vecmem::data::vector_buffer<unsigned int>
350+
link_last_measurement_buffer(n_candidates, mr.main);
351+
vecmem::data::vector_buffer<unsigned int> param_ids_buffer(
352+
n_candidates, mr.main);
353+
354+
/*
355+
* First, we sort the tracks by the index of their final
356+
* measurement which is critical to ensure good performance.
357+
*/
358+
queue
359+
.submit([&](::sycl::handler& h) {
360+
h.parallel_for<
361+
kernels::fill_finding_duplicate_removal_sort_keys<
362+
kernel_t>>(
363+
calculate1DimNdRange(n_candidates, 256),
364+
[links_view = vecmem::get_data(links_buffer),
365+
param_liveness_view =
366+
vecmem::get_data(param_liveness_buffer),
367+
link_last_measurement_view =
368+
vecmem::get_data(link_last_measurement_buffer),
369+
param_ids_view = vecmem::get_data(param_ids_buffer),
370+
n_candidates,
371+
curr_links_idx = step_to_link_idx_map[step],
372+
n_measurements](::sycl::nd_item<1> item) {
373+
device::fill_finding_duplicate_removal_sort_keys(
374+
details::global_index(item),
375+
{.links_view = links_view,
376+
.param_liveness_view = param_liveness_view,
377+
.link_last_measurement_view =
378+
link_last_measurement_view,
379+
.param_ids_view = param_ids_view,
380+
.n_links = n_candidates,
381+
.curr_links_idx = curr_links_idx,
382+
.n_measurements = n_measurements});
383+
});
384+
})
385+
.wait_and_throw();
386+
387+
vecmem::device_vector<unsigned int> keys_device(
388+
link_last_measurement_buffer);
389+
vecmem::device_vector<unsigned int> param_ids_device(
390+
param_ids_buffer);
391+
oneapi::dpl::sort_by_key(policy, keys_device.begin(),
392+
keys_device.end(),
393+
param_ids_device.begin());
394+
queue.wait_and_throw();
395+
396+
/*
397+
* Then, we run the actual duplicate removal kernel.
398+
*/
399+
queue
400+
.submit([&](::sycl::handler& h) {
401+
h.parallel_for<kernels::remove_duplicates<kernel_t>>(
402+
calculate1DimNdRange(n_candidates, 256),
403+
[config, links_view = vecmem::get_data(links_buffer),
404+
link_last_measurement_view =
405+
vecmem::get_data(link_last_measurement_buffer),
406+
param_ids_view = vecmem::get_data(param_ids_buffer),
407+
param_liveness_view =
408+
vecmem::get_data(param_liveness_buffer),
409+
n_candidates,
410+
curr_links_idx = step_to_link_idx_map[step],
411+
n_measurements, step](::sycl::nd_item<1> item) {
412+
device::remove_duplicates(
413+
details::global_index(item), config,
414+
{.links_view = links_view,
415+
.link_last_measurement_view =
416+
link_last_measurement_view,
417+
.param_ids_view = param_ids_view,
418+
.param_liveness_view = param_liveness_view,
419+
.n_links = n_candidates,
420+
.curr_links_idx = curr_links_idx,
421+
.n_measurements = n_measurements,
422+
.step = step});
423+
});
424+
})
425+
.wait_and_throw();
426+
}
427+
337428
if (step == config.max_track_candidates_per_track - 1) {
338429
break;
339430
}

tests/sycl/test_ckf_toy_detector.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ TEST_P(CkfToyDetectorTests, Run) {
138138
cfg.ptc_hypothesis = ptc;
139139
cfg.max_num_branches_per_seed = 500;
140140
cfg.propagation.navigation.search_window = search_window;
141-
cfg.duplicate_removal_minimum_length = 100u;
142141

143142
// Finding algorithm object
144143
traccc::host::combinatorial_kalman_filter_algorithm host_finding(cfg,

0 commit comments

Comments
 (0)