Skip to content

Commit a6d9ed2

Browse files
committed
Optimize the bitonic sort with warp-level primitives
1 parent c24dfd2 commit a6d9ed2

File tree

1 file changed

+54
-18
lines changed

1 file changed

+54
-18
lines changed

device/cuda/src/ambiguity_resolution/kernels/remove_tracks.cu

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@
2626

2727
namespace traccc::cuda::kernels {
2828

29+
__device__ __forceinline__ uint64_t pack_key(uint32_t meas, uint32_t thr) {
30+
return (uint64_t(meas) << 32) | uint64_t(thr);
31+
}
32+
__device__ __forceinline__ void unpack_key(uint64_t k, uint32_t& meas,
33+
uint32_t& thr) {
34+
meas = uint32_t(k >> 32);
35+
thr = uint32_t(k & 0xFFFFFFFFu);
36+
}
37+
2938
__device__ void count_tracks(int tid, int* sh_n_meas, int n_tracks,
3039
unsigned int& bound, unsigned int& count,
3140
bool& stop) {
@@ -81,6 +90,7 @@ __launch_bounds__(512) __global__
8190
__shared__ int sh_buffer[512];
8291
__shared__ measurement_id_type sh_meas_ids[512];
8392
__shared__ unsigned int sh_threads[512];
93+
__shared__ uint64_t sh_keys[512];
8494
__shared__ unsigned int n_meas_total;
8595
__shared__ unsigned int bound;
8696
__shared__ unsigned int n_tracks_to_iterate;
@@ -193,32 +203,58 @@ __launch_bounds__(512) __global__
193203
__syncthreads();
194204

195205
const auto tid = threadIndex;
206+
// No early return: out-of-range threads carry a sentinel and only
207+
// sync/shuffle.
208+
uint64_t key = (tid < N) ? pack_key(sh_meas_ids[tid], sh_threads[tid])
209+
: 0xFFFFFFFFFFFFFFFFull; // sentinel that won't
210+
// affect in-range items
211+
196212
for (int k = 2; k <= N; k <<= 1) {
213+
// Inter-warp (j >= 32): use shared + barriers
214+
for (int j = (k >> 1); j >= warpSize; j >>= 1) {
215+
sh_keys[tid] = key; // safe: sh_keys sized to blockDim.x
216+
__syncthreads();
197217

198-
bool ascending = ((tid & k) == 0);
218+
const int ixj = tid ^ j;
219+
// If partner is out-of-range, compare with self (no change).
220+
uint64_t other = (ixj < N) ? sh_keys[ixj] : key;
199221

200-
for (int j = k >> 1; j > 0; j >>= 1) {
201-
int ixj = tid ^ j;
222+
const bool dir = ((tid & k) == 0); // ascending segment?
223+
const bool lower = ((tid & j) == 0); // am I lower index?
202224

203-
if (ixj > tid && ixj < N && tid < N) {
204-
auto meas_i = sh_meas_ids[tid];
205-
auto meas_j = sh_meas_ids[ixj];
206-
auto thread_i = sh_threads[tid];
207-
auto thread_j = sh_threads[ixj];
225+
const uint64_t mn = (key < other) ? key : other;
226+
const uint64_t mx = (key < other) ? other : key;
208227

209-
bool should_swap =
210-
(meas_i > meas_j ||
211-
(meas_i == meas_j && thread_i > thread_j)) == ascending;
228+
key = dir ? (lower ? mn : mx) : (lower ? mx : mn);
212229

213-
if (should_swap) {
214-
sh_meas_ids[tid] = meas_j;
215-
sh_meas_ids[ixj] = meas_i;
216-
sh_threads[tid] = thread_j;
217-
sh_threads[ixj] = thread_i;
218-
}
219-
}
220230
__syncthreads();
221231
}
232+
233+
// Intra-warp (j < 32): warp shuffles only; no barriers
234+
for (int j = min(k >> 1, warpSize >> 1); j > 0; j >>= 1) {
235+
const unsigned mask = 0xFFFFFFFFu;
236+
uint64_t other = __shfl_xor_sync(mask, key, j);
237+
238+
const bool dir = ((tid & k) == 0);
239+
const bool lower = ((tid & j) == 0);
240+
241+
const uint64_t mn = (key < other) ? key : other;
242+
const uint64_t mx = (key < other) ? other : key;
243+
244+
key = dir ? (lower ? mn : mx) : (lower ? mx : mn);
245+
}
246+
247+
// Commit for next inter-warp round visibility
248+
sh_keys[tid] = key;
249+
__syncthreads();
250+
}
251+
252+
// Write back only in-range threads
253+
if (tid < N) {
254+
uint32_t meas, thr;
255+
unpack_key(key, meas, thr);
256+
sh_meas_ids[tid] = meas;
257+
sh_threads[tid] = thr;
222258
}
223259

224260
// Find starting point

0 commit comments

Comments
 (0)