diff --git a/device/cuda/src/ambiguity_resolution/kernels/sort_updated_tracks.cu b/device/cuda/src/ambiguity_resolution/kernels/sort_updated_tracks.cu index 3ffcedd09e..010e677c02 100644 --- a/device/cuda/src/ambiguity_resolution/kernels/sort_updated_tracks.cu +++ b/device/cuda/src/ambiguity_resolution/kernels/sort_updated_tracks.cu @@ -18,11 +18,17 @@ namespace traccc::cuda::kernels { __launch_bounds__(512) __global__ void sort_updated_tracks(device::sort_updated_tracks_payload payload) { - if (*(payload.terminate) == 1 || *(payload.n_updated_tracks) == 0) { + const unsigned int n_updated = *(payload.n_updated_tracks); + + if (*(payload.terminate) == 1 || n_updated == 0 || n_updated == 1) { return; } - __shared__ unsigned int shared_mem_tracks[512]; + // Shared: track id + keys cached once (no global reads inside compare + // loops) + __shared__ unsigned int sh_trk[512]; + __shared__ traccc::scalar sh_rel[512]; + __shared__ traccc::scalar sh_pval[512]; vecmem::device_vector rel_shared( payload.rel_shared_view); @@ -32,66 +38,77 @@ __launch_bounds__(512) __global__ const unsigned int tid = threadIdx.x; - // Load to shared memory - shared_mem_tracks[tid] = std::numeric_limits::max(); + // Padding the number of tracks to the power of 2 + const unsigned int N = 1 << (32 - __clz(n_updated - 1)); + + // Sentinel keys to push to the end + const unsigned int TRK_SENT = std::numeric_limits::max(); + const traccc::scalar REL_INF = + std::numeric_limits::infinity(); + const traccc::scalar PVAL_MIN = traccc::scalar(0); + + // Load once to shared (coalesced on updated_tracks) + if (tid < n_updated) { + unsigned int trk = updated_tracks[tid]; + sh_trk[tid] = trk; + sh_rel[tid] = rel_shared[trk]; + sh_pval[tid] = pvals[trk]; + } else { + sh_trk[tid] = TRK_SENT; + sh_rel[tid] = + REL_INF; // bigger rel → goes to the end for ascending rel + sh_pval[tid] = PVAL_MIN; // tie-breaker doesn't matter once rel=INF + } - if (tid < *(payload.n_updated_tracks)) { - shared_mem_tracks[tid] = updated_tracks[tid]; + // For any threads beyond N, still need to participate in barriers, + // but give them sentinel content so they don't affect ordering. + if (tid >= N) { + sh_trk[tid] = TRK_SENT; + sh_rel[tid] = REL_INF; + sh_pval[tid] = PVAL_MIN; } __syncthreads(); - // Padding the number of tracks to the power of 2 - const unsigned int N = 1 << (32 - __clz(*(payload.n_updated_tracks) - 1)); - - traccc::scalar rel_i; - traccc::scalar rel_j; - traccc::scalar pval_i; - traccc::scalar pval_j; - - // Bitonic sort - for (int k = 2; k <= N; k <<= 1) { + // Bitonic sort over shared data only + for (unsigned int k = 2; k <= N; k <<= 1) { - bool ascending = ((tid & k) == 0); + const bool ascending = ((tid & k) == 0); - for (int j = k >> 1; j > 0; j >>= 1) { - int ixj = tid ^ j; + for (unsigned int j = k >> 1; j > 0; j >>= 1) { + const unsigned int ixj = tid ^ j; if (ixj > tid && ixj < N && tid < N) { - unsigned int trk_i = shared_mem_tracks[tid]; - unsigned int trk_j = shared_mem_tracks[ixj]; - - if (trk_i == std::numeric_limits::max()) { - rel_i = std::numeric_limits::max(); - pval_i = 0.f; - } else { - rel_i = rel_shared[trk_i]; - pval_i = pvals[trk_i]; - } - - if (trk_j == std::numeric_limits::max()) { - rel_j = std::numeric_limits::max(); - pval_j = 0.f; - } else { - rel_j = rel_shared[trk_j]; - pval_j = pvals[trk_j]; - } - - bool should_swap = - (rel_i > rel_j || (rel_i == rel_j && pval_i < pval_j)) == - ascending; - + // Load to registers + unsigned int trk_i = sh_trk[tid]; + unsigned int trk_j = sh_trk[ixj]; + traccc::scalar rel_i = sh_rel[tid]; + traccc::scalar rel_j = sh_rel[ixj]; + traccc::scalar pval_i = sh_pval[tid]; + traccc::scalar pval_j = sh_pval[ixj]; + + // Compare: ascending by rel, and for equal rel, descending by + // pval + const bool greater = + (rel_i > rel_j) || ((rel_i == rel_j) && (pval_i < pval_j)); + + const bool should_swap = (greater == ascending); if (should_swap) { - shared_mem_tracks[tid] = trk_j; - shared_mem_tracks[ixj] = trk_i; + // swap triad + sh_trk[tid] = trk_j; + sh_trk[ixj] = trk_i; + sh_rel[tid] = rel_j; + sh_rel[ixj] = rel_i; + sh_pval[tid] = pval_j; + sh_pval[ixj] = pval_i; } } __syncthreads(); } } - if (tid < *(payload.n_updated_tracks)) { - updated_tracks[tid] = shared_mem_tracks[tid]; + if (tid < n_updated) { + updated_tracks[tid] = sh_trk[tid]; } }