diff --git a/device/common/include/traccc/ambiguity_resolution/device/gather_tracks.hpp b/device/common/include/traccc/ambiguity_resolution/device/gather_tracks.hpp index c3246b5037..5b657df64c 100644 --- a/device/common/include/traccc/ambiguity_resolution/device/gather_tracks.hpp +++ b/device/common/include/traccc/ambiguity_resolution/device/gather_tracks.hpp @@ -41,6 +41,11 @@ struct gather_tracks_payload { */ vecmem::data::vector_view sorted_ids_view; + /** + * @brief View object to the updated track + */ + vecmem::data::vector_view updated_tracks_view; + /** * @brief View object to the whether track id is updated */ diff --git a/device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu b/device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu index 54d310bbf1..0b4e7744db 100644 --- a/device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu +++ b/device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu @@ -420,9 +420,6 @@ greedy_ambiguity_resolution_algorithm::operator()( unsigned int nBlocks_warp = (n_accepted + nThreads_warp - 1) / nThreads_warp; - unsigned int nThreads_full = 1024; - unsigned int nBlocks_full = (n_tracks + 1023) / 1024; - unsigned int nThreads_rearrange = 1024; unsigned int nBlocks_rearrange = (n_accepted + (nThreads_rearrange / kernels::nThreads_per_track) - 1) / @@ -591,14 +588,16 @@ greedy_ambiguity_resolution_algorithm::operator()( .temp_sorted_ids_view = temp_sorted_ids_buffer, }); - kernels::gather_tracks<<>>( - device::gather_tracks_payload{ - .terminate = terminate_device.get(), - .n_accepted = n_accepted_device.get(), - .n_updated_tracks = n_updated_tracks_device.get(), - .temp_sorted_ids_view = temp_sorted_ids_buffer, - .sorted_ids_view = sorted_ids_buffer, - .is_updated_view = is_updated_buffer}); + kernels:: + gather_tracks<<>>( + device::gather_tracks_payload{ + .terminate = terminate_device.get(), + .n_accepted = n_accepted_device.get(), + .n_updated_tracks = n_updated_tracks_device.get(), + .temp_sorted_ids_view = temp_sorted_ids_buffer, + .sorted_ids_view = sorted_ids_buffer, + .updated_tracks_view = updated_tracks_buffer, + .is_updated_view = is_updated_buffer}); cudaStreamEndCapture(stream, &graph); cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0); diff --git a/device/cuda/src/ambiguity_resolution/kernels/gather_tracks.cu b/device/cuda/src/ambiguity_resolution/kernels/gather_tracks.cu index 21adc224dd..c3f383949d 100644 --- a/device/cuda/src/ambiguity_resolution/kernels/gather_tracks.cu +++ b/device/cuda/src/ambiguity_resolution/kernels/gather_tracks.cu @@ -23,14 +23,16 @@ __global__ void gather_tracks(device::gather_tracks_payload payload) { vecmem::device_vector temp_sorted_ids( payload.temp_sorted_ids_view); vecmem::device_vector sorted_ids(payload.sorted_ids_view); + vecmem::device_vector updated_tracks( + payload.updated_tracks_view); vecmem::device_vector is_updated(payload.is_updated_view); auto globalIndex = threadIdx.x + blockIdx.x * blockDim.x; const unsigned int n_accepted = *(payload.n_accepted); // Reset is_updated vector - if (globalIndex < is_updated.size()) { - is_updated[globalIndex] = 0; + if (globalIndex < *(payload.n_updated_tracks)) { + is_updated[updated_tracks[globalIndex]] = 0; } if (globalIndex >= n_accepted) {