Skip to content

Commit f05fa6c

Browse files
authored
Merge pull request #1093 from beomki-yeo/optimize-remove-tracks
Optimize `remove_tracks` by skipping unnecessary iteration
2 parents ca2cfcc + 83edf77 commit f05fa6c

File tree

5 files changed

+28
-34
lines changed

5 files changed

+28
-34
lines changed

device/common/include/traccc/ambiguity_resolution/device/count_removable_tracks.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ struct count_removable_tracks_payload {
8686
* @brief View object to the thread ids of the measurements to remove
8787
*/
8888
vecmem::data::vector_view<unsigned int> threads_view;
89+
90+
/**
91+
* @brief The number of threads that can each remove their corresponding track
92+
*/
93+
unsigned int* n_valid_threads;
8994
};
9095

9196
} // namespace traccc::device

device/common/include/traccc/ambiguity_resolution/device/remove_tracks.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,6 @@ struct remove_tracks_payload {
7979
*/
8080
unsigned int* n_removable_tracks;
8181

82-
/**
83-
* @brief The number of measurements to remove
84-
*/
85-
unsigned int* n_meas_to_remove;
86-
8782
/**
8883
* @brief View object to measurements to remove
8984
*/
@@ -113,6 +108,11 @@ struct remove_tracks_payload {
113108
* @brief View object recording whether each track id is updated
114109
*/
115110
vecmem::data::vector_view<int> is_updated_view;
111+
112+
/**
113+
* @brief The number of threads that can each remove their corresponding track
114+
*/
115+
unsigned int* n_valid_threads;
116116
};
117117

118118
} // namespace traccc::device

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ greedy_ambiguity_resolution_algorithm::operator()(
375375
vecmem::make_unique_alloc<unsigned int>(m_mr.main);
376376
vecmem::unique_alloc_ptr<unsigned int> n_meas_to_remove_device =
377377
vecmem::make_unique_alloc<unsigned int>(m_mr.main);
378+
vecmem::unique_alloc_ptr<unsigned int> n_valid_threads_device =
379+
vecmem::make_unique_alloc<unsigned int>(m_mr.main);
378380

379381
// Device objects
380382
int is_first_iteration = 1;
@@ -459,9 +461,10 @@ greedy_ambiguity_resolution_algorithm::operator()(
459461
.n_removable_tracks = n_removable_tracks_device.get(),
460462
.n_meas_to_remove = n_meas_to_remove_device.get(),
461463
.meas_to_remove_view = meas_to_remove_buffer,
462-
.threads_view = threads_buffer});
464+
.threads_view = threads_buffer,
465+
.n_valid_threads = n_valid_threads_device.get()});
463466

464-
kernels::remove_tracks<<<1, 1024, 0, stream>>>(
467+
kernels::remove_tracks<<<1, 512, 0, stream>>>(
465468
device::remove_tracks_payload{
466469
.sorted_ids_view = sorted_ids_buffer,
467470
.n_accepted = n_accepted_device.get(),
@@ -476,13 +479,13 @@ greedy_ambiguity_resolution_algorithm::operator()(
476479
.n_shared_view = n_shared_buffer,
477480
.rel_shared_view = rel_shared_buffer,
478481
.n_removable_tracks = n_removable_tracks_device.get(),
479-
.n_meas_to_remove = n_meas_to_remove_device.get(),
480482
.meas_to_remove_view = meas_to_remove_buffer,
481483
.threads_view = threads_buffer,
482484
.terminate = terminate_device.get(),
483485
.n_updated_tracks = n_updated_tracks_device.get(),
484486
.updated_tracks_view = updated_tracks_buffer,
485-
.is_updated_view = is_updated_buffer});
487+
.is_updated_view = is_updated_buffer,
488+
.n_valid_threads = n_valid_threads_device.get()});
486489

487490
// The seven kernels below are to keep sorted_ids sorted based on
488491
// the relative shared measurements and pvalues. This can be reduced

device/cuda/src/ambiguity_resolution/kernels/count_removable_tracks.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ __launch_bounds__(512) __global__ void count_removable_tracks(
102102
if (threadIndex == 0) {
103103
*(payload.n_removable_tracks) = 0;
104104
*(payload.n_meas_to_remove) = 0;
105+
*(payload.n_valid_threads) = 0;
105106
n_meas_total = 0;
106107
bound = 512;
107108
N = 1;
@@ -242,20 +243,12 @@ __launch_bounds__(512) __global__ void count_removable_tracks(
242243

243244
__syncthreads();
244245

245-
auto n_meas_to_remove_temp = *(payload.n_meas_to_remove);
246-
247-
if (threadIndex == 0) {
248-
*(payload.n_meas_to_remove) = 0;
249-
}
250-
251-
__syncthreads();
252-
253246
int is_valid =
254247
(threads[threadIndex] < *(payload.n_removable_tracks)) ? 1 : 0;
255248

256249
// TODO: Use better reduction algorithm
257250
if (is_valid) {
258-
atomicAdd(payload.n_meas_to_remove, 1);
251+
atomicAdd(payload.n_valid_threads, 1);
259252
}
260253

261254
__syncthreads();
@@ -264,7 +257,7 @@ __launch_bounds__(512) __global__ void count_removable_tracks(
264257
prefix[threadIndex] = is_valid; // copy input
265258
__syncthreads();
266259

267-
for (int offset = 1; offset < n_meas_to_remove_temp; offset <<= 1) {
260+
for (int offset = 1; offset < *(payload.n_meas_to_remove); offset <<= 1) {
268261
int val = 0;
269262
if (threadIndex >= offset) {
270263
val = prefix[threadIndex - offset];

device/cuda/src/ambiguity_resolution/kernels/remove_tracks.cu

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,16 @@
2626

2727
namespace traccc::cuda::kernels {
2828

29-
__global__ void remove_tracks(device::remove_tracks_payload payload) {
29+
__launch_bounds__(512) __global__
30+
void remove_tracks(device::remove_tracks_payload payload) {
3031

3132
if (*(payload.terminate) == 1) {
3233
return;
3334
}
3435

35-
__shared__ unsigned int shared_tids[1024];
36-
__shared__ measurement_id_type sh_meas_ids[1024];
37-
__shared__ unsigned int sh_threads[1024];
36+
__shared__ unsigned int shared_tids[512];
37+
__shared__ measurement_id_type sh_meas_ids[512];
38+
__shared__ unsigned int sh_threads[512];
3839

3940
auto threadIndex = threadIdx.x;
4041

@@ -73,7 +74,7 @@ __global__ void remove_tracks(device::remove_tracks_payload payload) {
7374
(*payload.n_accepted) -= *(payload.n_removable_tracks);
7475
}
7576

76-
if (threadIndex < *(payload.n_meas_to_remove)) {
77+
if (threadIndex < *(payload.n_valid_threads)) {
7778
sh_meas_ids[threadIndex] = meas_to_remove[threadIndex];
7879
sh_threads[threadIndex] = threads[threadIndex];
7980
is_valid_thread = true;
@@ -82,16 +83,8 @@ __global__ void remove_tracks(device::remove_tracks_payload payload) {
8283
__syncthreads();
8384

8485
if (is_valid_thread) {
85-
8686
const auto id = sh_meas_ids[threadIndex];
87-
is_duplicate = false;
88-
89-
for (unsigned int i = 0; i < threadIndex; ++i) {
90-
if (sh_meas_ids[i] == id) {
91-
is_duplicate = true;
92-
break;
93-
}
94-
}
87+
is_duplicate = (threadIndex > 0 && sh_meas_ids[threadIndex - 1] == id);
9588
}
9689

9790
bool active = false;
@@ -116,7 +109,7 @@ __global__ void remove_tracks(device::remove_tracks_payload payload) {
116109
track_status[worst_idx] = 0;
117110

118111
int n_sharing_tracks = 1;
119-
for (unsigned int i = threadIndex + 1; i < *(payload.n_meas_to_remove);
112+
for (unsigned int i = threadIndex + 1; i < *(payload.n_valid_threads);
120113
++i) {
121114

122115
if (sh_meas_ids[i] == id && sh_threads[i] != sh_threads[i - 1]) {

0 commit comments

Comments
 (0)