Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 63 additions & 46 deletions device/cuda/src/ambiguity_resolution/kernels/sort_updated_tracks.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@ namespace traccc::cuda::kernels {
__launch_bounds__(512) __global__
void sort_updated_tracks(device::sort_updated_tracks_payload payload) {

if (*(payload.terminate) == 1 || *(payload.n_updated_tracks) == 0) {
const unsigned int n_updated = *(payload.n_updated_tracks);

if (*(payload.terminate) == 1 || n_updated == 0 || n_updated == 1) {
return;
}

__shared__ unsigned int shared_mem_tracks[512];
// Shared: track id + keys cached once (no global reads inside compare
// loops)
__shared__ unsigned int sh_trk[512];
__shared__ traccc::scalar sh_rel[512];
__shared__ traccc::scalar sh_pval[512];

vecmem::device_vector<const traccc::scalar> rel_shared(
payload.rel_shared_view);
Expand All @@ -32,66 +38,77 @@ __launch_bounds__(512) __global__

const unsigned int tid = threadIdx.x;

// Load to shared memory
shared_mem_tracks[tid] = std::numeric_limits<unsigned int>::max();
// Padding the number of tracks to the power of 2
const unsigned int N = 1 << (32 - __clz(n_updated - 1));

// Sentinel keys to push to the end
const unsigned int TRK_SENT = std::numeric_limits<unsigned int>::max();
const traccc::scalar REL_INF =
std::numeric_limits<traccc::scalar>::infinity();
const traccc::scalar PVAL_MIN = traccc::scalar(0);

// Load once to shared (coalesced on updated_tracks)
if (tid < n_updated) {
unsigned int trk = updated_tracks[tid];
sh_trk[tid] = trk;
sh_rel[tid] = rel_shared[trk];
sh_pval[tid] = pvals[trk];
} else {
sh_trk[tid] = TRK_SENT;
sh_rel[tid] =
REL_INF; // bigger rel → goes to the end for ascending rel
sh_pval[tid] = PVAL_MIN; // tie-breaker doesn't matter once rel=INF
}

if (tid < *(payload.n_updated_tracks)) {
shared_mem_tracks[tid] = updated_tracks[tid];
// For any threads beyond N, still need to participate in barriers,
// but give them sentinel content so they don't affect ordering.
if (tid >= N) {
sh_trk[tid] = TRK_SENT;
sh_rel[tid] = REL_INF;
sh_pval[tid] = PVAL_MIN;
}

__syncthreads();

// Padding the number of tracks to the power of 2
const unsigned int N = 1 << (32 - __clz(*(payload.n_updated_tracks) - 1));

traccc::scalar rel_i;
traccc::scalar rel_j;
traccc::scalar pval_i;
traccc::scalar pval_j;

// Bitonic sort
for (int k = 2; k <= N; k <<= 1) {
// Bitonic sort over shared data only
for (unsigned int k = 2; k <= N; k <<= 1) {

bool ascending = ((tid & k) == 0);
const bool ascending = ((tid & k) == 0);

for (int j = k >> 1; j > 0; j >>= 1) {
int ixj = tid ^ j;
for (unsigned int j = k >> 1; j > 0; j >>= 1) {
const unsigned int ixj = tid ^ j;

if (ixj > tid && ixj < N && tid < N) {
unsigned int trk_i = shared_mem_tracks[tid];
unsigned int trk_j = shared_mem_tracks[ixj];

if (trk_i == std::numeric_limits<unsigned int>::max()) {
rel_i = std::numeric_limits<traccc::scalar>::max();
pval_i = 0.f;
} else {
rel_i = rel_shared[trk_i];
pval_i = pvals[trk_i];
}

if (trk_j == std::numeric_limits<unsigned int>::max()) {
rel_j = std::numeric_limits<traccc::scalar>::max();
pval_j = 0.f;
} else {
rel_j = rel_shared[trk_j];
pval_j = pvals[trk_j];
}

bool should_swap =
(rel_i > rel_j || (rel_i == rel_j && pval_i < pval_j)) ==
ascending;

// Load to registers
unsigned int trk_i = sh_trk[tid];
unsigned int trk_j = sh_trk[ixj];
traccc::scalar rel_i = sh_rel[tid];
traccc::scalar rel_j = sh_rel[ixj];
traccc::scalar pval_i = sh_pval[tid];
traccc::scalar pval_j = sh_pval[ixj];

// Compare: ascending by rel, and for equal rel, descending by
// pval
const bool greater =
(rel_i > rel_j) || ((rel_i == rel_j) && (pval_i < pval_j));

const bool should_swap = (greater == ascending);
if (should_swap) {
shared_mem_tracks[tid] = trk_j;
shared_mem_tracks[ixj] = trk_i;
// swap triad
sh_trk[tid] = trk_j;
sh_trk[ixj] = trk_i;
sh_rel[tid] = rel_j;
sh_rel[ixj] = rel_i;
sh_pval[tid] = pval_j;
sh_pval[ixj] = pval_i;
}
}
__syncthreads();
}
}

if (tid < *(payload.n_updated_tracks)) {
updated_tracks[tid] = shared_mem_tracks[tid];
if (tid < n_updated) {
updated_tracks[tid] = sh_trk[tid];
}
}

Expand Down
Loading