|
21 | 21 | #include "traccc/finding/details/combinatorial_kalman_filter_types.hpp" |
22 | 22 | #include "traccc/finding/device/apply_interaction.hpp" |
23 | 23 | #include "traccc/finding/device/build_tracks.hpp" |
| 24 | +#include "traccc/finding/device/fill_finding_duplicate_removal_sort_keys.hpp" |
24 | 25 | #include "traccc/finding/device/fill_finding_propagation_sort_keys.hpp" |
25 | 26 | #include "traccc/finding/device/find_tracks.hpp" |
26 | 27 | #include "traccc/finding/device/make_barcode_sequence.hpp" |
27 | 28 | #include "traccc/finding/device/propagate_to_next_surface.hpp" |
| 29 | +#include "traccc/finding/device/remove_duplicates.hpp" |
28 | 30 | #include "traccc/finding/finding_config.hpp" |
29 | 31 | #include "traccc/utils/logging.hpp" |
30 | 32 | #include "traccc/utils/memory_resource.hpp" |
@@ -105,8 +107,37 @@ struct find_tracks { |
105 | 107 | } |
106 | 108 | }; |
107 | 109 |
|
108 | | -/// Alpaka kernel functor for @c |
109 | | -/// traccc::device::fill_finding_propagation_sort_keys |
| 110 | +/// Alpaka kernel functor for |
| 111 | +/// @c traccc::device::fill_finding_duplicate_removal_sort_keys |
| 112 | +struct fill_finding_duplicate_removal_sort_keys { |
| 113 | + template <typename TAcc> |
| 114 | + ALPAKA_FN_ACC void operator()( |
| 115 | + TAcc const& acc, |
| 116 | + const device::fill_finding_duplicate_removal_sort_keys_payload& payload) |
| 117 | + const { |
| 118 | + |
| 119 | + const device::global_index_t globalThreadIdx = |
| 120 | + ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0]; |
| 121 | + device::fill_finding_duplicate_removal_sort_keys(globalThreadIdx, |
| 122 | + payload); |
| 123 | + } |
| 124 | +}; |
| 125 | + |
| 126 | +/// Alpaka kernel functor for @c traccc::device::remove_duplicates |
| 127 | +struct remove_duplicates { |
| 128 | + template <typename TAcc> |
| 129 | + ALPAKA_FN_ACC void operator()( |
| 130 | + TAcc const& acc, const finding_config& cfg, |
| 131 | + const device::remove_duplicates_payload& payload) const { |
| 132 | + |
| 133 | + const device::global_index_t globalThreadIdx = |
| 134 | + ::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0]; |
| 135 | + device::remove_duplicates(globalThreadIdx, cfg, payload); |
| 136 | + } |
| 137 | +}; |
| 138 | + |
| 139 | +/// Alpaka kernel functor for |
| 140 | +/// @c traccc::device::fill_finding_propagation_sort_keys |
110 | 141 | struct fill_finding_propagation_sort_keys { |
111 | 142 | template <typename TAcc> |
112 | 143 | ALPAKA_FN_ACC void operator()( |
@@ -408,6 +439,74 @@ combinatorial_kalman_filter( |
408 | 439 | step_to_link_idx_map[step + 1] - step_to_link_idx_map[step]; |
409 | 440 | } |
410 | 441 |
|
| 442 | + /* |
| 443 | + * On later steps, we can duplicate removal which will attempt to find |
| 444 | + * tracks that are propagated multiple times and deduplicate them. |
| 445 | + */ |
| 446 | + if (n_candidates > 0 && |
| 447 | + step >= config.duplicate_removal_minimum_length) { |
| 448 | + vecmem::data::vector_buffer<unsigned int> |
| 449 | + link_last_measurement_buffer(n_candidates, mr.main); |
| 450 | + vecmem::data::vector_buffer<unsigned int> param_ids_buffer( |
| 451 | + n_candidates, mr.main); |
| 452 | + |
| 453 | + /* |
| 454 | + * First, we sort the tracks by the index of their final |
| 455 | + * measurement which is critical to ensure good performance. |
| 456 | + */ |
| 457 | + { |
| 458 | + const unsigned int nThreads = 256; |
| 459 | + const unsigned int nBlocks = |
| 460 | + (n_candidates + nThreads - 1) / nThreads; |
| 461 | + const auto workDiv = makeWorkDiv<Acc>(nBlocks, nThreads); |
| 462 | + |
| 463 | + ::alpaka::exec<Acc>( |
| 464 | + queue, workDiv, |
| 465 | + kernels::fill_finding_duplicate_removal_sort_keys{}, |
| 466 | + device::fill_finding_duplicate_removal_sort_keys_payload{ |
| 467 | + .links_view = links_buffer, |
| 468 | + .param_liveness_view = param_liveness_buffer, |
| 469 | + .link_last_measurement_view = |
| 470 | + link_last_measurement_buffer, |
| 471 | + .param_ids_view = param_ids_buffer, |
| 472 | + .n_links = n_candidates, |
| 473 | + .curr_links_idx = step_to_link_idx_map[step], |
| 474 | + .n_measurements = n_measurements}); |
| 475 | + ::alpaka::wait(queue); |
| 476 | + } |
| 477 | + |
| 478 | + vecmem::device_vector<unsigned int> keys_device( |
| 479 | + link_last_measurement_buffer); |
| 480 | + vecmem::device_vector<unsigned int> param_ids_device( |
| 481 | + param_ids_buffer); |
| 482 | + thrust::sort_by_key(thrustExecPolicy, keys_device.begin(), |
| 483 | + keys_device.end(), param_ids_device.begin()); |
| 484 | + |
| 485 | + /* |
| 486 | + * Then, we run the actual duplicate removal kernel. |
| 487 | + */ |
| 488 | + { |
| 489 | + const unsigned int nThreads = 256; |
| 490 | + const unsigned int nBlocks = |
| 491 | + (n_candidates + nThreads - 1) / nThreads; |
| 492 | + const auto workDiv = makeWorkDiv<Acc>(nBlocks, nThreads); |
| 493 | + |
| 494 | + ::alpaka::exec<Acc>( |
| 495 | + queue, workDiv, kernels::remove_duplicates{}, config, |
| 496 | + device::remove_duplicates_payload{ |
| 497 | + .links_view = links_buffer, |
| 498 | + .link_last_measurement_view = |
| 499 | + link_last_measurement_buffer, |
| 500 | + .param_ids_view = param_ids_buffer, |
| 501 | + .param_liveness_view = param_liveness_buffer, |
| 502 | + .n_links = n_candidates, |
| 503 | + .curr_links_idx = step_to_link_idx_map[step], |
| 504 | + .n_measurements = n_measurements, |
| 505 | + .step = step}); |
| 506 | + ::alpaka::wait(queue); |
| 507 | + } |
| 508 | + } |
| 509 | + |
411 | 510 | if (step == config.max_track_candidates_per_track - 1) { |
412 | 511 | break; |
413 | 512 | } |
@@ -566,3 +665,22 @@ struct BlockSharedMemDynSizeBytes< |
566 | 665 | }; |
567 | 666 |
|
568 | 667 | } // namespace alpaka::trait |
| 668 | + |
| 669 | +namespace alpaka { |
| 670 | + |
| 671 | +/// Convince Alpaka that |
| 672 | +/// @c traccc::device::fill_finding_duplicate_removal_sort_keys_payload |
| 673 | +/// is trivially copyable |
| 674 | +template <> |
| 675 | +struct IsKernelArgumentTriviallyCopyable< |
| 676 | + traccc::device::fill_finding_duplicate_removal_sort_keys_payload, void> |
| 677 | + : std::true_type {}; |
| 678 | + |
| 679 | +/// Convince Alpaka that |
| 680 | +/// @c traccc::device::remove_duplicates_payload |
| 681 | +/// is trivially copyable |
| 682 | +template <> |
| 683 | +struct IsKernelArgumentTriviallyCopyable< |
| 684 | + traccc::device::remove_duplicates_payload, void> : std::true_type {}; |
| 685 | + |
| 686 | +} // namespace alpaka |
0 commit comments