Skip to content

Commit e6f5e34

Browse files
authored
Remove the iterations in deduplication (#1126)
1 parent 2490295 commit e6f5e34

File tree

3 files changed

+30
-23
lines changed

3 files changed

+30
-23
lines changed

device/common/include/traccc/ambiguity_resolution/device/remove_tracks.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ struct remove_tracks_payload {
113113
* @brief The number of threads that can remove its corresponding track
114114
*/
115115
unsigned int* n_valid_threads;
116+
117+
/**
118+
* @brief View of the vector holding the per-track counts used during the removal process
119+
*/
120+
vecmem::data::vector_view<int> track_count_view;
116121
};
117122

118123
} // namespace traccc::device

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ greedy_ambiguity_resolution_algorithm::operator()(
365365
vecmem::data::vector_buffer<int> is_updated_buffer{n_tracks, m_mr.main};
366366
m_copy.get().setup(inverted_ids_buffer)->ignore();
367367

368+
// Count track id appearances during the removal process
369+
vecmem::data::vector_buffer<int> track_count_buffer{n_tracks, m_mr.main};
370+
m_copy.get().setup(track_count_buffer)->ignore();
371+
m_copy.get().memset(track_count_buffer, 0)->ignore();
372+
368373
// Prefix sum buffer
369374
vecmem::data::vector_buffer<int> prefix_sums_buffer{n_tracks, m_mr.main};
370375
m_copy.get().setup(prefix_sums_buffer)->ignore();
@@ -511,7 +516,8 @@ greedy_ambiguity_resolution_algorithm::operator()(
511516
.n_updated_tracks = n_updated_tracks_device.get(),
512517
.updated_tracks_view = updated_tracks_buffer,
513518
.is_updated_view = is_updated_buffer,
514-
.n_valid_threads = n_valid_threads_device.get()});
519+
.n_valid_threads = n_valid_threads_device.get(),
520+
.track_count_view = track_count_buffer});
515521

516522
// The seven kernels below are to keep sorted_ids sorted based on
517523
// the relative shared measurements and pvalues. This can be reduced

device/cuda/src/ambiguity_resolution/kernels/remove_tracks.cu

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ __launch_bounds__(512) __global__
111111
vecmem::device_vector<unsigned int> updated_tracks(
112112
payload.updated_tracks_view);
113113
vecmem::device_vector<int> is_updated(payload.is_updated_view);
114+
vecmem::device_vector<int> track_count(payload.track_count_view);
114115

115116
if (threadIndex == 0) {
116117
*(payload.n_removable_tracks) = 0;
@@ -342,6 +343,7 @@ __launch_bounds__(512) __global__
342343

343344
bool active = false;
344345
unsigned int pos1;
346+
int alive_trk_id = 0;
345347

346348
if (!is_duplicate && is_valid_thread) {
347349

@@ -397,46 +399,40 @@ __launch_bounds__(512) __global__
397399
track_status.begin();
398400

399401
pos1 = atomicAdd(&n_updating_threads, 1);
402+
alive_trk_id = static_cast<int>(tracks[alive_idx]);
400403

401-
sh_buffer[pos1] = static_cast<int>(tracks[alive_idx]);
404+
sh_buffer[pos1] = alive_trk_id;
405+
atomicAdd(&track_count[alive_trk_id], 1);
402406

403-
auto tid = sh_buffer[pos1];
407+
const auto m_count = static_cast<unsigned int>(
408+
thrust::count(thrust::seq, meas_ids[alive_trk_id].begin(),
409+
meas_ids[alive_trk_id].end(), id));
404410

405-
const auto m_count = static_cast<unsigned int>(thrust::count(
406-
thrust::seq, meas_ids[tid].begin(), meas_ids[tid].end(), id));
407-
408-
const unsigned int N_S =
409-
vecmem::device_atomic_ref<unsigned int>(n_shared.at(tid))
410-
.fetch_sub(m_count);
411+
const unsigned int N_S = vecmem::device_atomic_ref<unsigned int>(
412+
n_shared.at(alive_trk_id))
413+
.fetch_sub(m_count);
411414
}
412415
}
413416

414417
__syncthreads();
415418

416419
if (active) {
417-
auto tid = sh_buffer[pos1];
418-
bool already_pushed = false;
419-
for (unsigned int i = 0; i < pos1; ++i) {
420-
if (sh_buffer[i] == tid) {
421-
already_pushed = true;
422-
break;
423-
}
424-
}
425420

426-
if (!already_pushed) {
421+
auto count = atomicAdd(&track_count[alive_trk_id], -1);
422+
if (count == 1) {
427423

428424
// Write updated track IDs
429425
vecmem::device_atomic_ref<unsigned int> num_updated_tracks(
430426
*(payload.n_updated_tracks));
431427

432428
const unsigned int pos2 = num_updated_tracks.fetch_add(1);
433429

434-
updated_tracks[pos2] = tid;
435-
is_updated[tid] = 1;
430+
updated_tracks[pos2] = alive_trk_id;
431+
is_updated[alive_trk_id] = 1;
436432

437-
rel_shared.at(tid) =
438-
math::div_ieee754(static_cast<traccc::scalar>(n_shared.at(tid)),
439-
static_cast<traccc::scalar>(n_meas.at(tid)));
433+
rel_shared.at(alive_trk_id) = math::div_ieee754(
434+
static_cast<traccc::scalar>(n_shared.at(alive_trk_id)),
435+
static_cast<traccc::scalar>(n_meas.at(alive_trk_id)));
440436
}
441437
}
442438
}

0 commit comments

Comments
 (0)