Skip to content

Commit 327189d

Browse files
authored
Merge pull request #1095 from beomki-yeo/optimize-remove-tracks-3
Reduce the unnecessary iterations in `remove_tracks`
2 parents f05fa6c + 29d475f commit 327189d

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

device/cuda/src/ambiguity_resolution/kernels/remove_tracks.cu

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ __launch_bounds__(512) __global__
3636
__shared__ unsigned int shared_tids[512];
3737
__shared__ measurement_id_type sh_meas_ids[512];
3838
__shared__ unsigned int sh_threads[512];
39+
__shared__ unsigned int n_updating_threads;
3940

4041
auto threadIndex = threadIdx.x;
4142

@@ -72,6 +73,7 @@ __launch_bounds__(512) __global__
7273

7374
if (threadIndex == 0) {
7475
(*payload.n_accepted) -= *(payload.n_removable_tracks);
76+
n_updating_threads = 0;
7577
}
7678

7779
if (threadIndex < *(payload.n_valid_threads)) {
@@ -88,6 +90,7 @@ __launch_bounds__(512) __global__
8890
}
8991

9092
bool active = false;
93+
unsigned int pos1;
9194

9295
if (!is_duplicate && is_valid_thread) {
9396

@@ -141,10 +144,11 @@ __launch_bounds__(512) __global__
141144
track_status.end(), 1) -
142145
track_status.begin();
143146

144-
shared_tids[threadIndex] =
145-
static_cast<unsigned int>(tracks[alive_idx]);
147+
pos1 = atomicAdd(&n_updating_threads, 1);
146148

147-
auto tid = shared_tids[threadIndex];
149+
shared_tids[pos1] = static_cast<unsigned int>(tracks[alive_idx]);
150+
151+
auto tid = shared_tids[pos1];
148152

149153
const auto m_count = static_cast<unsigned int>(thrust::count(
150154
thrust::seq, meas_ids[tid].begin(), meas_ids[tid].end(), id));
@@ -158,23 +162,24 @@ __launch_bounds__(512) __global__
158162
__syncthreads();
159163

160164
if (active) {
161-
auto tid = shared_tids[threadIndex];
165+
auto tid = shared_tids[pos1];
162166
bool already_pushed = false;
163-
for (unsigned int i = 0; i < threadIndex; ++i) {
167+
for (unsigned int i = 0; i < pos1; ++i) {
164168
if (shared_tids[i] == tid) {
165169
already_pushed = true;
166170
break;
167171
}
168172
}
173+
169174
if (!already_pushed) {
170175

171176
// Write updated track IDs
172177
vecmem::device_atomic_ref<unsigned int> num_updated_tracks(
173178
*(payload.n_updated_tracks));
174179

175-
const unsigned int pos = num_updated_tracks.fetch_add(1);
180+
const unsigned int pos2 = num_updated_tracks.fetch_add(1);
176181

177-
updated_tracks[pos] = tid;
182+
updated_tracks[pos2] = tid;
178183
is_updated[tid] = 1;
179184

180185
rel_shared.at(tid) = static_cast<traccc::scalar>(n_shared.at(tid)) /

0 commit comments

Comments
 (0)