Skip to content

Commit 748fe50

Browse files
authored
Merge pull request #1046 from stephenswat/bug/duplicate_fill_kernel
Deduplicate sorting key filling CUDA kernels
2 parents 673d408 + 91e11ec commit 748fe50

17 files changed

+106
-91
lines changed

device/alpaka/src/finding/combinatorial_kalman_filter.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "traccc/finding/details/combinatorial_kalman_filter_types.hpp"
2222
#include "traccc/finding/device/apply_interaction.hpp"
2323
#include "traccc/finding/device/build_tracks.hpp"
24-
#include "traccc/finding/device/fill_sort_keys.hpp"
24+
#include "traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
2525
#include "traccc/finding/device/find_tracks.hpp"
2626
#include "traccc/finding/device/make_barcode_sequence.hpp"
2727
#include "traccc/finding/device/propagate_to_next_surface.hpp"
@@ -105,15 +105,18 @@ struct find_tracks {
105105
}
106106
};
107107

108-
/// Alpaka kernel functor for @c traccc::device::fill_sort_keys
109-
struct fill_sort_keys {
108+
/// Alpaka kernel functor for @c
109+
/// traccc::device::fill_finding_propagation_sort_keys
110+
struct fill_finding_propagation_sort_keys {
110111
template <typename TAcc>
111112
ALPAKA_FN_ACC void operator()(
112-
TAcc const& acc, const device::fill_sort_keys_payload payload) const {
113+
TAcc const& acc,
114+
const device::fill_finding_propagation_sort_keys_payload payload)
115+
const {
113116

114117
const device::global_index_t globalThreadIdx =
115118
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
116-
device::fill_sort_keys(globalThreadIdx, payload);
119+
device::fill_finding_propagation_sort_keys(globalThreadIdx, payload);
117120
}
118121
};
119122

@@ -428,8 +431,9 @@ combinatorial_kalman_filter(
428431
auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
429432

430433
::alpaka::exec<Acc>(
431-
queue, workDiv, kernels::fill_sort_keys{},
432-
device::fill_sort_keys_payload{
434+
queue, workDiv,
435+
kernels::fill_finding_propagation_sort_keys{},
436+
device::fill_finding_propagation_sort_keys_payload{
433437
in_params_buffer, keys_buffer, param_ids_buffer});
434438
::alpaka::wait(queue);
435439

device/alpaka/src/fitting/kalman_fitting.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "traccc/edm/track_candidate_container.hpp"
1616
#include "traccc/edm/track_state.hpp"
1717
#include "traccc/fitting/details/kalman_fitting_types.hpp"
18-
#include "traccc/fitting/device/fill_sort_keys.hpp"
18+
#include "traccc/fitting/device/fill_fitting_sort_keys.hpp"
1919
#include "traccc/fitting/device/fit.hpp"
2020
#include "traccc/fitting/device/fit_backward.hpp"
2121
#include "traccc/fitting/device/fit_forward.hpp"
@@ -33,8 +33,8 @@
3333
namespace traccc::alpaka::details {
3434
namespace kernels {
3535

36-
/// Alpaka kernel functor for @c traccc::device::fill_sort_keys
37-
struct fill_sort_keys {
36+
/// Alpaka kernel functor for @c traccc::device::fill_fitting_sort_keys
37+
struct fill_fitting_sort_keys {
3838
template <typename TAcc>
3939
ALPAKA_FN_ACC void operator()(
4040
TAcc const& acc,
@@ -45,8 +45,8 @@ struct fill_sort_keys {
4545

4646
const device::global_index_t globalThreadIdx =
4747
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
48-
device::fill_sort_keys(globalThreadIdx, track_candidates_view,
49-
keys_view, ids_view);
48+
device::fill_fitting_sort_keys(globalThreadIdx, track_candidates_view,
49+
keys_view, ids_view);
5050
}
5151
};
5252

@@ -190,9 +190,10 @@ track_state_container_types::buffer kalman_fitting(
190190
const auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
191191

192192
// Fill the keys and param_ids buffers.
193-
::alpaka::exec<Acc>(
194-
queue, workDiv, kernels::fill_sort_keys{}, track_candidates_view.tracks,
195-
vecmem::get_data(keys_buffer), vecmem::get_data(param_ids_buffer));
193+
::alpaka::exec<Acc>(queue, workDiv, kernels::fill_fitting_sort_keys{},
194+
track_candidates_view.tracks,
195+
vecmem::get_data(keys_buffer),
196+
vecmem::get_data(param_ids_buffer));
196197
::alpaka::wait(queue);
197198

198199
// Sort the key to get the sorted parameter ids

device/common/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,20 @@ traccc_add_library( traccc_device_common device_common TYPE INTERFACE
6161
"include/traccc/finding/device/apply_interaction.hpp"
6262
"include/traccc/finding/device/build_tracks.hpp"
6363
"include/traccc/finding/device/find_tracks.hpp"
64-
"include/traccc/finding/device/fill_sort_keys.hpp"
64+
"include/traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
6565
"include/traccc/finding/device/make_barcode_sequence.hpp"
6666
"include/traccc/finding/device/propagate_to_next_surface.hpp"
6767
"include/traccc/finding/device/impl/apply_interaction.ipp"
6868
"include/traccc/finding/device/impl/build_tracks.ipp"
6969
"include/traccc/finding/device/impl/find_tracks.ipp"
70-
"include/traccc/finding/device/impl/fill_sort_keys.ipp"
70+
"include/traccc/finding/device/impl/fill_finding_propagation_sort_keys.ipp"
7171
"include/traccc/finding/device/impl/make_barcode_sequence.ipp"
7272
"include/traccc/finding/device/impl/propagate_to_next_surface.ipp"
7373
# Track fitting funtions(s).
7474
"include/traccc/fitting/device/fit.hpp"
7575
"include/traccc/fitting/device/impl/fit.ipp"
76-
"include/traccc/fitting/device/fill_sort_keys.hpp"
77-
"include/traccc/fitting/device/impl/fill_sort_keys.ipp"
76+
"include/traccc/fitting/device/fill_finding_propagation_sort_keys.hpp"
77+
"include/traccc/fitting/device/impl/fill_finding_propagation_sort_keys.ipp"
7878
)
7979
target_link_libraries( traccc_device_common
8080
INTERFACE traccc::core vecmem::core )

device/common/include/traccc/finding/device/fill_sort_keys.hpp renamed to device/common/include/traccc/finding/device/fill_finding_propagation_sort_keys.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919

2020
namespace traccc::device {
2121

22-
/// (Event Data) Payload for the @c traccc::device::fill_sort_keys function
23-
struct fill_sort_keys_payload {
22+
/// (Event Data) Payload for the @c
23+
/// traccc::device::fill_finding_propagation_sort_keys function
24+
struct fill_finding_propagation_sort_keys_payload {
2425
/**
2526
* @brief View object to the vector of bound track parameters
2627
*/
@@ -43,10 +44,11 @@ struct fill_sort_keys_payload {
4344
/// @param[in] globalIndex The index of the current thread
4445
/// @param[inout] payload The function call payload
4546
///
46-
TRACCC_HOST_DEVICE inline void fill_sort_keys(
47-
global_index_t globalIndex, const fill_sort_keys_payload& payload);
47+
TRACCC_HOST_DEVICE inline void fill_finding_propagation_sort_keys(
48+
global_index_t globalIndex,
49+
const fill_finding_propagation_sort_keys_payload& payload);
4850

4951
} // namespace traccc::device
5052

5153
// Include the implementation.
52-
#include "./impl/fill_sort_keys.ipp"
54+
#include "./impl/fill_finding_propagation_sort_keys.ipp"

device/common/include/traccc/finding/device/impl/fill_sort_keys.ipp renamed to device/common/include/traccc/finding/device/impl/fill_finding_propagation_sort_keys.ipp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
namespace traccc::device {
1111

12-
TRACCC_HOST_DEVICE inline void fill_sort_keys(
13-
const global_index_t globalIndex, const fill_sort_keys_payload& payload) {
12+
TRACCC_HOST_DEVICE inline void fill_finding_propagation_sort_keys(
13+
const global_index_t globalIndex,
14+
const fill_finding_propagation_sort_keys_payload& payload) {
1415

1516
const bound_track_parameters_collection_types::const_device params(
1617
payload.params_view);

device/common/include/traccc/fitting/device/fill_sort_keys.hpp renamed to device/common/include/traccc/fitting/device/fill_fitting_sort_keys.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace traccc::device {
2323
/// @param[out] keys_view The key values
2424
/// @param[out] ids_view The param ids
2525
///
26-
TRACCC_HOST_DEVICE inline void fill_sort_keys(
26+
TRACCC_HOST_DEVICE inline void fill_fitting_sort_keys(
2727
global_index_t globalIndex,
2828
const edm::track_candidate_collection<default_algebra>::const_view&
2929
track_candidates_view,
@@ -33,4 +33,4 @@ TRACCC_HOST_DEVICE inline void fill_sort_keys(
3333
} // namespace traccc::device
3434

3535
// Include the implementation.
36-
#include "traccc/fitting/device/impl/fill_sort_keys.ipp"
36+
#include "traccc/fitting/device/impl/fill_fitting_sort_keys.ipp"

device/common/include/traccc/fitting/device/impl/fill_sort_keys.ipp renamed to device/common/include/traccc/fitting/device/impl/fill_fitting_sort_keys.ipp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace traccc::device {
1111

12-
TRACCC_HOST_DEVICE inline void fill_sort_keys(
12+
TRACCC_HOST_DEVICE inline void fill_fitting_sort_keys(
1313
const global_index_t globalIndex,
1414
const edm::track_candidate_collection<default_algebra>::const_view&
1515
track_candidates_view,

device/cuda/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ traccc_add_library( traccc_cuda cuda TYPE SHARED
6464
"src/finding/kernels/make_barcode_sequence.cu"
6565
"src/finding/kernels/make_barcode_sequence.cuh"
6666
"src/finding/kernels/apply_interaction.hpp"
67-
"src/finding/kernels/fill_sort_keys.cu"
68-
"src/finding/kernels/fill_sort_keys.cuh"
67+
"src/finding/kernels/fill_finding_propagation_sort_keys.cu"
68+
"src/finding/kernels/fill_finding_propagation_sort_keys.cuh"
6969
"src/finding/kernels/build_tracks.cu"
7070
"src/finding/kernels/build_tracks.cuh"
7171
"src/finding/kernels/find_tracks.cuh"
@@ -119,7 +119,7 @@ traccc_add_library( traccc_cuda cuda TYPE SHARED
119119
"src/fitting/kalman_fitting_algorithm_default_detector.cu"
120120
"src/fitting/kalman_fitting_algorithm_telescope_detector.cu"
121121
"src/fitting/kalman_fitting.cuh"
122-
"src/fitting/kernels/fill_sort_keys.cu"
122+
"src/fitting/kernels/fill_fitting_sort_keys.cu"
123123
"src/fitting/kernels/fit_prelude.cu"
124124
"src/fitting/kernels/specializations/fit_forward_constant_field_default_detector.cu"
125125
"src/fitting/kernels/specializations/fit_forward_constant_field_telescope_detector.cu"

device/cuda/src/finding/combinatorial_kalman_filter.cuh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "../utils/utils.hpp"
1616
#include "./kernels/apply_interaction.hpp"
1717
#include "./kernels/build_tracks.cuh"
18-
#include "./kernels/fill_sort_keys.cuh"
18+
#include "./kernels/fill_finding_propagation_sort_keys.cuh"
1919
#include "./kernels/find_tracks.cuh"
2020
#include "./kernels/make_barcode_sequence.cuh"
2121
#include "./kernels/propagate_to_next_surface.hpp"
@@ -330,7 +330,8 @@ combinatorial_kalman_filter(
330330
const unsigned int nThreads = warp_size * 2;
331331
const unsigned int nBlocks =
332332
(n_candidates + nThreads - 1) / nThreads;
333-
kernels::fill_sort_keys<<<nBlocks, nThreads, 0, stream>>>(
333+
kernels::fill_finding_propagation_sort_keys<<<nBlocks, nThreads,
334+
0, stream>>>(
334335
{.params_view = in_params_buffer,
335336
.keys_view = keys_buffer,
336337
.ids_view = param_ids_buffer});
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2023-2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
// Local include(s).
9+
#include "../../utils/global_index.hpp"
10+
#include "fill_finding_propagation_sort_keys.cuh"
11+
12+
// Project include(s).
13+
#include "traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
14+
15+
namespace traccc::cuda::kernels {
16+
17+
__global__ void fill_finding_propagation_sort_keys(
18+
device::fill_finding_propagation_sort_keys_payload payload) {
19+
20+
device::fill_finding_propagation_sort_keys(details::global_index1(),
21+
payload);
22+
}
23+
24+
} // namespace traccc::cuda::kernels

0 commit comments

Comments
 (0)