Skip to content

Commit 5b7bd1b

Browse files
committed
Optimize gather_tracks
1 parent ca5971a commit 5b7bd1b

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

device/common/include/traccc/ambiguity_resolution/device/gather_tracks.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ struct gather_tracks_payload {
4141
*/
4242
vecmem::data::vector_view<unsigned int> sorted_ids_view;
4343

44+
/**
45+
* @brief View object to the ids of the updated tracks
46+
*/
47+
vecmem::data::vector_view<unsigned int> updated_tracks_view;
48+
4449
/**
4550
* @brief View object to the flags indicating whether each track id is updated
4651
*/

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,6 @@ greedy_ambiguity_resolution_algorithm::operator()(
412412
unsigned int nBlocks_warp =
413413
(n_accepted + nThreads_warp - 1) / nThreads_warp;
414414

415-
unsigned int nThreads_full = 1024;
416-
unsigned int nBlocks_full = (n_tracks + 1023) / 1024;
417-
418415
// Compute the threadblock dimension for scanning kernels
419416
auto compute_scan_config = [&](unsigned int n_accepted) {
420417
unsigned int nThreads_scan = m_warp_size * 4;
@@ -580,13 +577,14 @@ greedy_ambiguity_resolution_algorithm::operator()(
580577
.temp_sorted_ids_view = temp_sorted_ids_buffer,
581578
});
582579

583-
kernels::gather_tracks<<<nBlocks_full, nThreads_full, 0, stream>>>(
580+
kernels::gather_tracks<<<nBlocks_adaptive, nThreads_adaptive, 0, stream>>>(
584581
device::gather_tracks_payload{
585582
.terminate = terminate_device.get(),
586583
.n_accepted = n_accepted_device.get(),
587584
.n_updated_tracks = n_updated_tracks_device.get(),
588585
.temp_sorted_ids_view = temp_sorted_ids_buffer,
589586
.sorted_ids_view = sorted_ids_buffer,
587+
.updated_tracks_view = updated_tracks_buffer,
590588
.is_updated_view = is_updated_buffer});
591589

592590
cudaStreamEndCapture(stream, &graph);

device/cuda/src/ambiguity_resolution/kernels/gather_tracks.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,16 @@ __global__ void gather_tracks(device::gather_tracks_payload payload) {
2323
vecmem::device_vector<const unsigned int> temp_sorted_ids(
2424
payload.temp_sorted_ids_view);
2525
vecmem::device_vector<unsigned int> sorted_ids(payload.sorted_ids_view);
26+
vecmem::device_vector<unsigned int> updated_tracks(
27+
payload.updated_tracks_view);
2628
vecmem::device_vector<int> is_updated(payload.is_updated_view);
2729

2830
auto globalIndex = threadIdx.x + blockIdx.x * blockDim.x;
2931
const unsigned int n_accepted = *(payload.n_accepted);
3032

3133
// Reset is_updated vector
32-
if (globalIndex < is_updated.size()) {
33-
is_updated[globalIndex] = 0;
34+
if (globalIndex < *(payload.n_updated_tracks)) {
35+
is_updated[updated_tracks[globalIndex]] = 0;
3436
}
3537

3638
if (globalIndex >= n_accepted) {

0 commit comments

Comments
 (0)