Skip to content

Commit b3e82ce

Browse files
authored
Merge pull request #1141 from beomki-yeo/optimize-count-tracks-2
Split the `count_tracks` function into warp-level and block-level phases
2 parents 1241c09 + 253348a commit b3e82ce

File tree

1 file changed

+44
-38
lines changed

1 file changed

+44
-38
lines changed

device/cuda/src/ambiguity_resolution/kernels/remove_tracks.cu

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -36,36 +36,62 @@ __device__ __forceinline__ void unpack_key(uint64_t k, uint32_t& meas,
3636
}
3737

3838
__device__ void count_tracks(int tid, int* sh_n_meas, int n_tracks,
39-
unsigned int& bound, unsigned int& count,
40-
bool& stop) {
39+
unsigned int& bound, unsigned int& count) {
4140

4241
unsigned int add = 0;
43-
unsigned int offset = 0;
44-
for (unsigned int stride = 1; stride < (n_tracks - count); stride *= 2) {
45-
if ((count + tid + stride) < n_tracks) {
46-
sh_n_meas[count + tid] += sh_n_meas[count + tid + stride];
42+
43+
// --- Warp-level phase: handle strides < 32 using warp shuffle (no
44+
// __syncthreads needed) ---
45+
const int lane = threadIdx.x & 31;
46+
const unsigned int full_mask = 0xFFFFFFFFu;
47+
48+
// Load this thread's value into a register if it's in range
49+
int v = (tid < n_tracks) ? sh_n_meas[tid] : 0;
50+
51+
// Mask for active lanes in this warp
52+
const unsigned int mask = __ballot_sync(full_mask, tid < n_tracks);
53+
54+
const int max_stride = min(n_tracks, 32);
55+
56+
for (int stride = 1; stride < max_stride; stride <<= 1) {
57+
// Accumulate neighbor's value via warp shuffle
58+
unsigned int other = __shfl_down_sync(mask, v, stride);
59+
if (lane + stride < 32 && (tid + stride) < n_tracks) {
60+
v += other;
4761
}
48-
__syncthreads();
4962

50-
if (sh_n_meas[count] < bound) {
51-
if (tid == 0) {
52-
offset = sh_n_meas[count];
53-
add = stride * 2;
63+
// Thread 0 can directly check its register value in the warp phase
64+
if (tid == 0) {
65+
if (v < bound) {
66+
add = stride << 1;
5467
}
5568
}
69+
}
5670

57-
__syncthreads();
71+
// Write warp-phase result back to shared memory
72+
if (tid < n_tracks) {
73+
sh_n_meas[tid] = static_cast<int>(v);
5874
}
75+
__syncthreads();
5976

60-
if (tid == 0) {
61-
bound -= offset;
62-
count += add;
77+
// --- Block-level phase: handle strides >= 32 (minimal required
78+
// synchronizations) ---
79+
for (int stride = 32; stride < n_tracks; stride <<= 1) {
80+
if ((tid + stride) < n_tracks) {
81+
sh_n_meas[tid] += sh_n_meas[tid + stride];
82+
}
83+
__syncthreads();
6384

64-
if (add == 0) {
65-
stop = true;
85+
if (tid == 0 && sh_n_meas[0] < bound) {
86+
add = stride << 1;
6687
}
88+
__syncthreads();
6789
}
6890

91+
// --- Final update ---
92+
if (tid == 0) {
93+
count += add;
94+
}
6995
__syncthreads();
7096
}
7197

@@ -96,7 +122,6 @@ __launch_bounds__(512) __global__
96122
__shared__ unsigned int n_tracks_to_iterate;
97123
__shared__ unsigned int min_thread;
98124
__shared__ unsigned int N;
99-
__shared__ bool stop;
100125
__shared__ unsigned int n_updating_threads;
101126

102127
auto threadIndex = threadIdx.x;
@@ -135,7 +160,6 @@ __launch_bounds__(512) __global__
135160
N = 1;
136161
n_tracks_to_iterate = 0;
137162
min_thread = std::numeric_limits<unsigned int>::max();
138-
stop = false;
139163
}
140164

141165
__syncthreads();
@@ -158,26 +182,8 @@ __launch_bounds__(512) __global__
158182
* Count the number of removable tracks
159183
****************************************/
160184

161-
// @TODO: Improve the logic
162185
count_tracks(threadIdx.x, sh_buffer, n_tracks_total, bound,
163-
n_tracks_to_iterate, stop);
164-
/*
165-
for (int i = 0; i < 100; i++) {
166-
count_tracks(threadIdx.x, shared_n_meas, n_tracks_total, bound,
167-
n_tracks_to_iterate, stop);
168-
__syncthreads();
169-
if (stop)
170-
break;
171-
172-
if (gid >= 0 && static_cast<unsigned int>(gid) < sorted_ids.size()) {
173-
const auto trk_id = sorted_ids[gid];
174-
if (trk_id < n_meas.size()) {
175-
shared_n_meas[threadIndex] = n_meas[trk_id];
176-
}
177-
}
178-
__syncthreads();
179-
}
180-
*/
186+
n_tracks_to_iterate);
181187

182188
if (threadIndex == 0 && n_tracks_to_iterate == 0) {
183189
n_tracks_to_iterate = 1;

0 commit comments

Comments
 (0)