Skip to content

Commit d8f5b2e

Browse files
committed
Replace the bubble sort with bitonic sort
1 parent 726978a commit d8f5b2e

File tree

3 files changed

+50
-30
lines changed

3 files changed

+50
-30
lines changed

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,7 @@ greedy_ambiguity_resolution_algorithm::operator()(
542542
// when the number of updated tracks <= 1024) and might be faster
543543
// with large number of updated tracks
544544

545-
kernels::sort_updated_tracks<<<1, 1024, 1024 * sizeof(unsigned int),
546-
stream>>>(
545+
kernels::sort_updated_tracks<<<1, 512, 0, stream>>>(
547546
device::sort_updated_tracks_payload{
548547
.rel_shared_view = rel_shared_buffer,
549548
.pvals_view = pvals_buffer,

device/cuda/src/ambiguity_resolution/kernels/count_removable_tracks.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ __launch_bounds__(512) __global__ void count_removable_tracks(
171171

172172
const auto tid = threadIndex;
173173
for (int k = 2; k <= N; k <<= 1) {
174+
175+
bool ascending = ((tid & k) == 0);
176+
174177
for (int j = k >> 1; j > 0; j >>= 1) {
175178
int ixj = tid ^ j;
176179

@@ -180,7 +183,6 @@ __launch_bounds__(512) __global__ void count_removable_tracks(
180183
auto thread_i = sh_threads[tid];
181184
auto thread_j = sh_threads[ixj];
182185

183-
bool ascending = ((tid & k) == 0);
184186
bool should_swap =
185187
(meas_i > meas_j ||
186188
(meas_i == meas_j && thread_i > thread_j)) == ascending;

device/cuda/src/ambiguity_resolution/kernels/sort_updated_tracks.cu

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,82 @@
1515

1616
namespace traccc::cuda::kernels {
1717

18-
__global__ void sort_updated_tracks(
19-
device::sort_updated_tracks_payload payload) {
18+
__launch_bounds__(512) __global__
19+
void sort_updated_tracks(device::sort_updated_tracks_payload payload) {
2020

2121
if (*(payload.terminate) == 1 || *(payload.n_updated_tracks) == 0) {
2222
return;
2323
}
2424

25-
extern __shared__ unsigned int shared_mem_tracks[];
25+
__shared__ unsigned int shared_mem_tracks[512];
2626

2727
vecmem::device_vector<const traccc::scalar> rel_shared(
2828
payload.rel_shared_view);
2929
vecmem::device_vector<const traccc::scalar> pvals(payload.pvals_view);
3030
vecmem::device_vector<unsigned int> updated_tracks(
3131
payload.updated_tracks_view);
3232

33-
const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
34-
const unsigned int N = *(payload.n_updated_tracks);
33+
const unsigned int tid = threadIdx.x;
3534

3635
// Load to shared memory
37-
if (tid < N) {
36+
shared_mem_tracks[tid] = std::numeric_limits<unsigned int>::max();
37+
38+
if (tid < *(payload.n_updated_tracks)) {
3839
shared_mem_tracks[tid] = updated_tracks[tid];
3940
}
4041

4142
__syncthreads();
4243

43-
for (int iter = 0; iter < N; ++iter) {
44-
bool is_even = (iter % 2 == 0);
45-
int i = tid;
44+
// Padding the number of tracks to the power of 2
45+
const unsigned int N = 1 << (32 - __clz(*(payload.n_updated_tracks) - 1));
46+
47+
traccc::scalar rel_i;
48+
traccc::scalar rel_j;
49+
traccc::scalar pval_i;
50+
traccc::scalar pval_j;
51+
52+
// Bitonic sort
53+
for (int k = 2; k <= N; k <<= 1) {
54+
55+
bool ascending = ((tid & k) == 0);
4656

47-
if (i < N / 2) {
48-
int idx = 2 * i + (is_even ? 0 : 1);
49-
if (idx + 1 < N) {
50-
unsigned int a = shared_mem_tracks[idx];
51-
unsigned int b = shared_mem_tracks[idx + 1];
57+
for (int j = k >> 1; j > 0; j >>= 1) {
58+
int ixj = tid ^ j;
5259

53-
traccc::scalar rel_a = rel_shared[a];
54-
traccc::scalar rel_b = rel_shared[b];
55-
traccc::scalar pv_a = pvals[a];
56-
traccc::scalar pv_b = pvals[b];
60+
if (ixj > tid && ixj < N && tid < N) {
61+
unsigned int trk_i = shared_mem_tracks[tid];
62+
unsigned int trk_j = shared_mem_tracks[ixj];
5763

58-
bool swap = false;
59-
if (rel_a != rel_b) {
60-
swap = rel_a > rel_b;
64+
if (trk_i == std::numeric_limits<unsigned int>::max()) {
65+
rel_i = std::numeric_limits<traccc::scalar>::max();
66+
pval_i = 0.f;
6167
} else {
62-
swap = pv_a < pv_b;
68+
rel_i = rel_shared[trk_i];
69+
pval_i = pvals[trk_i];
6370
}
6471

65-
if (swap) {
66-
shared_mem_tracks[idx] = b;
67-
shared_mem_tracks[idx + 1] = a;
72+
if (trk_j == std::numeric_limits<unsigned int>::max()) {
73+
rel_j = std::numeric_limits<traccc::scalar>::max();
74+
pval_j = 0.f;
75+
} else {
76+
rel_j = rel_shared[trk_j];
77+
pval_j = pvals[trk_j];
78+
}
79+
80+
bool should_swap =
81+
(rel_i > rel_j || (rel_i == rel_j && pval_i < pval_j)) ==
82+
ascending;
83+
84+
if (should_swap) {
85+
shared_mem_tracks[tid] = trk_j;
86+
shared_mem_tracks[ixj] = trk_i;
6887
}
6988
}
89+
__syncthreads();
7090
}
71-
__syncthreads();
7291
}
7392

74-
if (tid < N) {
93+
if (tid < *(payload.n_updated_tracks)) {
7594
updated_tracks[tid] = shared_mem_tracks[tid];
7695
}
7796
}

0 commit comments

Comments
 (0)