|
26 | 26 |
|
27 | 27 | namespace traccc::cuda::kernels { |
28 | 28 |
|
/// Pack a (measurement id, thread id) pair into a single 64-bit sort key.
///
/// The measurement id occupies the high 32 bits so it acts as the primary
/// ordering criterion; the thread id fills the low 32 bits as the tie
/// breaker. Comparing two packed keys with plain `<` therefore reproduces
/// the lexicographic (measurement, thread) ordering used by the sort.
__device__ __forceinline__ uint64_t pack_key(uint32_t meas, uint32_t thr) {
    const uint64_t hi = static_cast<uint64_t>(meas) << 32;
    const uint64_t lo = static_cast<uint64_t>(thr);
    return hi | lo;
}
/// Inverse of pack_key: split a 64-bit sort key back into its components.
///
/// The high 32 bits are written to @c meas (measurement id) and the low
/// 32 bits to @c thr (thread id).
__device__ __forceinline__ void unpack_key(uint64_t k, uint32_t& meas,
                                           uint32_t& thr) {
    meas = static_cast<uint32_t>(k >> 32);
    // Narrowing cast keeps only the low word; no explicit mask needed.
    thr = static_cast<uint32_t>(k);
}
| 37 | + |
29 | 38 | __device__ void count_tracks(int tid, int* sh_n_meas, int n_tracks, |
30 | 39 | unsigned int& bound, unsigned int& count, |
31 | 40 | bool& stop) { |
@@ -81,6 +90,7 @@ __launch_bounds__(512) __global__ |
81 | 90 | __shared__ int sh_buffer[512]; |
82 | 91 | __shared__ measurement_id_type sh_meas_ids[512]; |
83 | 92 | __shared__ unsigned int sh_threads[512]; |
| 93 | + __shared__ uint64_t sh_keys[512]; |
84 | 94 | __shared__ unsigned int n_meas_total; |
85 | 95 | __shared__ unsigned int bound; |
86 | 96 | __shared__ unsigned int n_tracks_to_iterate; |
@@ -193,32 +203,58 @@ __launch_bounds__(512) __global__ |
193 | 203 | __syncthreads(); |
194 | 204 |
|
195 | 205 | const auto tid = threadIndex; |
| 206 | + // No early return: out-of-range threads carry a sentinel and only |
| 207 | + // sync/shuffle. |
| 208 | + uint64_t key = (tid < N) ? pack_key(sh_meas_ids[tid], sh_threads[tid]) |
| 209 | + : 0xFFFFFFFFFFFFFFFFull; // sentinel that won't |
| 210 | + // affect in-range items |
| 211 | + |
196 | 212 | for (int k = 2; k <= N; k <<= 1) { |
| 213 | + // Inter-warp (j >= 32): use shared + barriers |
| 214 | + for (int j = (k >> 1); j >= warpSize; j >>= 1) { |
| 215 | + sh_keys[tid] = key; // safe: sh_keys sized to blockDim.x |
| 216 | + __syncthreads(); |
197 | 217 |
|
198 | | - bool ascending = ((tid & k) == 0); |
| 218 | + const int ixj = tid ^ j; |
| 219 | + // If partner is out-of-range, compare with self (no change). |
| 220 | + uint64_t other = (ixj < N) ? sh_keys[ixj] : key; |
199 | 221 |
|
200 | | - for (int j = k >> 1; j > 0; j >>= 1) { |
201 | | - int ixj = tid ^ j; |
| 222 | + const bool dir = ((tid & k) == 0); // ascending segment? |
| 223 | + const bool lower = ((tid & j) == 0); // am I lower index? |
202 | 224 |
|
203 | | - if (ixj > tid && ixj < N && tid < N) { |
204 | | - auto meas_i = sh_meas_ids[tid]; |
205 | | - auto meas_j = sh_meas_ids[ixj]; |
206 | | - auto thread_i = sh_threads[tid]; |
207 | | - auto thread_j = sh_threads[ixj]; |
| 225 | + const uint64_t mn = (key < other) ? key : other; |
| 226 | + const uint64_t mx = (key < other) ? other : key; |
208 | 227 |
|
209 | | - bool should_swap = |
210 | | - (meas_i > meas_j || |
211 | | - (meas_i == meas_j && thread_i > thread_j)) == ascending; |
| 228 | + key = dir ? (lower ? mn : mx) : (lower ? mx : mn); |
212 | 229 |
|
213 | | - if (should_swap) { |
214 | | - sh_meas_ids[tid] = meas_j; |
215 | | - sh_meas_ids[ixj] = meas_i; |
216 | | - sh_threads[tid] = thread_j; |
217 | | - sh_threads[ixj] = thread_i; |
218 | | - } |
219 | | - } |
220 | 230 | __syncthreads(); |
221 | 231 | } |
| 232 | + |
| 233 | + // Intra-warp (j < 32): warp shuffles only; no barriers |
| 234 | + for (int j = min(k >> 1, warpSize >> 1); j > 0; j >>= 1) { |
| 235 | + const unsigned mask = 0xFFFFFFFFu; |
| 236 | + uint64_t other = __shfl_xor_sync(mask, key, j); |
| 237 | + |
| 238 | + const bool dir = ((tid & k) == 0); |
| 239 | + const bool lower = ((tid & j) == 0); |
| 240 | + |
| 241 | + const uint64_t mn = (key < other) ? key : other; |
| 242 | + const uint64_t mx = (key < other) ? other : key; |
| 243 | + |
| 244 | + key = dir ? (lower ? mn : mx) : (lower ? mx : mn); |
| 245 | + } |
| 246 | + |
| 247 | + // Commit for next inter-warp round visibility |
| 248 | + sh_keys[tid] = key; |
| 249 | + __syncthreads(); |
| 250 | + } |
| 251 | + |
| 252 | + // Write back only in-range threads |
| 253 | + if (tid < N) { |
| 254 | + uint32_t meas, thr; |
| 255 | + unpack_key(key, meas, thr); |
| 256 | + sh_meas_ids[tid] = meas; |
| 257 | + sh_threads[tid] = thr; |
222 | 258 | } |
223 | 259 |
|
224 | 260 | // Find starting point |
|
0 commit comments