Skip to content

Commit 9457b4f

Browse files
authored
Optimize rearrange-tracks (#1122)
1 parent ca5971a commit 9457b4f

File tree

3 files changed

+126
-81
lines changed

3 files changed

+126
-81
lines changed

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,11 @@ greedy_ambiguity_resolution_algorithm::operator()(
415415
unsigned int nThreads_full = 1024;
416416
unsigned int nBlocks_full = (n_tracks + 1023) / 1024;
417417

418+
unsigned int nThreads_rearrange = 1024;
419+
unsigned int nBlocks_rearrange =
420+
(n_accepted + (nThreads_rearrange / kernels::nThreads_per_track) - 1) /
421+
(nThreads_rearrange / kernels::nThreads_per_track);
422+
418423
// Compute the threadblock dimension for scanning kernels
419424
auto compute_scan_config = [&](unsigned int n_accepted) {
420425
unsigned int nThreads_scan = m_warp_size * 4;
@@ -457,6 +462,10 @@ greedy_ambiguity_resolution_algorithm::operator()(
457462
scan_dim = compute_scan_config(n_accepted);
458463
nThreads_scan = scan_dim.first;
459464
nBlocks_scan = scan_dim.second;
465+
nBlocks_rearrange =
466+
(n_accepted + (nThreads_rearrange / kernels::nThreads_per_track) -
467+
1) /
468+
(nThreads_rearrange / kernels::nThreads_per_track);
460469

461470
// Make CUDA Graph
462471
cudaGraph_t graph;
@@ -565,7 +574,7 @@ greedy_ambiguity_resolution_algorithm::operator()(
565574
.block_offsets_view = scanned_block_offsets_buffer,
566575
.prefix_sums_view = prefix_sums_buffer});
567576

568-
kernels::rearrange_tracks<<<nBlocks_adaptive, nThreads_adaptive, 0,
577+
kernels::rearrange_tracks<<<nBlocks_rearrange, nThreads_rearrange, 0,
569578
stream>>>(device::rearrange_tracks_payload{
570579
.sorted_ids_view = sorted_ids_buffer,
571580
.inverted_ids_view = inverted_ids_buffer,

device/cuda/src/ambiguity_resolution/kernels/rearrange_tracks.cu

Lines changed: 113 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,28 @@ TRACCC_DEVICE inline bool find_valid_index(
4141
return false;
4242
}
4343

44-
__global__ void rearrange_tracks(device::rearrange_tracks_payload payload) {
44+
__launch_bounds__(1024) __global__
45+
void rearrange_tracks(device::rearrange_tracks_payload payload) {
4546

4647
if (*(payload.terminate) == 1 || *(payload.n_updated_tracks) == 0) {
4748
return;
4849
}
4950

50-
auto gid = threadIdx.x + blockIdx.x * blockDim.x;
51+
auto gid = threadIdx.x / nThreads_per_track +
52+
blockIdx.x * (blockDim.x / nThreads_per_track);
5153
const unsigned int n_accepted = *(payload.n_accepted);
5254

53-
if (gid >= n_accepted) {
54-
return;
55+
auto N = *(payload.n_updated_tracks);
56+
57+
int neff_threads = (N + nThreads_per_track - 1) / nThreads_per_track;
58+
59+
if (neff_threads > nThreads_per_track) {
60+
neff_threads = nThreads_per_track;
61+
}
62+
63+
bool is_valid_thread = true;
64+
if (threadIdx.x % nThreads_per_track >= neff_threads || gid >= n_accepted) {
65+
is_valid_thread = false;
5566
}
5667

5768
vecmem::device_vector<const unsigned int> sorted_ids(
@@ -68,116 +79,139 @@ __global__ void rearrange_tracks(device::rearrange_tracks_payload payload) {
6879
vecmem::device_vector<unsigned int> temp_sorted_ids(
6980
payload.temp_sorted_ids_view);
7081

71-
const auto tid = sorted_ids[gid];
72-
auto rel_sh_ref = rel_shared[tid];
73-
auto pval_ref = pvals[tid];
74-
int shifted_idx = static_cast<int>(gid);
75-
auto N = *(payload.n_updated_tracks);
82+
__shared__ int shifted_indices[1024];
83+
auto& shifted_idx = shifted_indices[threadIdx.x / nThreads_per_track];
84+
unsigned int tid = std::numeric_limits<unsigned int>::max();
7685

77-
if (is_updated[tid]) {
86+
if (is_valid_thread) {
7887

79-
if (gid > 0) {
88+
tid = sorted_ids[gid];
89+
auto rel_sh_ref = rel_shared[tid];
90+
auto pval_ref = pvals[tid];
8091

81-
unsigned int left = 0;
82-
unsigned int right = gid;
92+
shifted_idx = static_cast<int>(gid);
8393

84-
bool first_iteration = true;
85-
while (right > left) {
94+
int stride = (N + neff_threads - 1) / neff_threads;
8695

87-
const bool find_left =
88-
find_valid_index(left, 0, gid, sorted_ids, is_updated);
96+
int ini_idx = stride * (threadIdx.x % nThreads_per_track);
97+
int fin_idx = std::min(ini_idx + stride, static_cast<int>(N));
8998

90-
if (!find_left) {
91-
break;
92-
}
99+
if (is_updated[tid]) {
93100

94-
const bool find_right =
95-
find_valid_index(right, 0, gid, sorted_ids, is_updated);
101+
if (gid > 0) {
96102

97-
if (!find_right) {
98-
break;
99-
}
103+
unsigned int left = 0;
104+
unsigned int right = gid;
100105

101-
if (first_iteration) {
102-
auto rel_sh = rel_shared[sorted_ids[right]];
103-
auto pval = pvals[sorted_ids[right]];
106+
bool first_iteration = true;
104107

105-
if (rel_sh < rel_sh_ref ||
106-
(rel_sh == rel_sh_ref && pval >= pval_ref)) {
107-
left = gid;
108-
break;
109-
}
110-
}
108+
if (threadIdx.x % nThreads_per_track == 0) {
109+
110+
while (right > left) {
111+
112+
const bool find_left = find_valid_index(
113+
left, 0, gid, sorted_ids, is_updated);
114+
115+
if (!find_left) {
116+
break;
117+
}
118+
119+
const bool find_right = find_valid_index(
120+
right, 0, gid, sorted_ids, is_updated);
121+
122+
if (!find_right) {
123+
break;
124+
}
111125

112-
first_iteration = false;
126+
if (first_iteration) {
127+
const auto right_idx = sorted_ids[right];
128+
auto rel_sh = rel_shared[right_idx];
129+
auto pval = pvals[right_idx];
113130

114-
unsigned int mid = left + (right - left) / 2;
131+
if (rel_sh < rel_sh_ref ||
132+
(rel_sh == rel_sh_ref && pval >= pval_ref)) {
133+
left = gid;
134+
break;
135+
}
136+
}
115137

116-
const bool find_mid = find_valid_index(mid, left, right - 1,
117-
sorted_ids, is_updated);
138+
first_iteration = false;
118139

119-
if (find_mid) {
140+
unsigned int mid = left + (right - left) / 2;
120141

121-
auto rel_sh = rel_shared[sorted_ids[mid]];
122-
auto pval = pvals[sorted_ids[mid]];
142+
const bool find_mid = find_valid_index(
143+
mid, left, right - 1, sorted_ids, is_updated);
123144

124-
if (rel_sh < rel_sh_ref ||
125-
(rel_sh == rel_sh_ref && pval >= pval_ref)) {
145+
if (find_mid) {
126146

127-
left = mid + 1;
128-
} else {
129-
right = mid;
147+
const auto mid_idx = sorted_ids[mid];
148+
auto rel_sh = rel_shared[mid_idx];
149+
auto pval = pvals[mid_idx];
150+
151+
if (rel_sh < rel_sh_ref ||
152+
(rel_sh == rel_sh_ref && pval >= pval_ref)) {
153+
154+
left = mid + 1;
155+
} else {
156+
right = mid;
157+
}
158+
}
130159
}
131-
}
132-
}
133160

134-
int delta = 0;
161+
int delta = delta =
162+
gid - left - (prefix_sums[gid] - prefix_sums[left]);
135163

136-
if (is_updated[sorted_ids[left]]) {
137-
delta = gid - left - (prefix_sums[gid] - prefix_sums[left]);
138-
} else {
139-
delta = gid - left - (prefix_sums[gid] - prefix_sums[left] - 1);
140-
}
164+
if (!is_updated[sorted_ids[left]]) {
165+
delta++;
166+
}
141167

142-
shifted_idx -= delta;
143-
}
168+
atomicAdd(&shifted_idx, -delta);
169+
}
170+
}
144171

145-
for (int i = 0; i < N; i++) {
172+
for (int i = ini_idx; i < fin_idx; i++) {
146173

147-
auto id = updated_tracks[i];
174+
auto id = updated_tracks[i];
148175

149-
if (inverted_ids[id] < gid) {
150-
shifted_idx--;
176+
if (inverted_ids[id] < gid) {
177+
atomicAdd(&shifted_idx, -1);
178+
}
151179
}
152-
}
153180

154-
int offset = 0;
155-
for (int i = 0; i < N; i++) {
156-
if (updated_tracks[i] == tid) {
157-
offset = i;
158-
break;
181+
int offset = 0;
182+
for (int i = ini_idx; i < fin_idx; i++) {
183+
if (updated_tracks[i] == tid) {
184+
offset = i;
185+
break;
186+
}
159187
}
160-
}
161-
shifted_idx += offset;
188+
if (offset != 0) {
189+
atomicAdd(&shifted_idx, offset);
190+
}
191+
} else {
162192

163-
} else {
164-
for (int i = 0; i < N; i++) {
193+
for (int i = ini_idx; i < fin_idx; i++) {
165194

166-
auto id = updated_tracks[i];
167-
auto rel_sh = rel_shared[id];
168-
auto pval = pvals[id];
195+
auto id = updated_tracks[i];
196+
auto rel_sh = rel_shared[id];
197+
auto pval = pvals[id];
169198

170-
if (inverted_ids[id] > gid) {
171-
if (rel_sh < rel_sh_ref) {
172-
shifted_idx++;
173-
} else if (rel_sh == rel_sh_ref && pval > pval_ref) {
174-
shifted_idx++;
199+
if (inverted_ids[id] > gid) {
200+
if (rel_sh < rel_sh_ref) {
201+
atomicAdd(&shifted_idx, 1);
202+
} else if (rel_sh == rel_sh_ref && pval > pval_ref) {
203+
atomicAdd(&shifted_idx, 1);
204+
}
175205
}
176206
}
177207
}
178208
}
179209

180-
temp_sorted_ids.at(shifted_idx) = tid;
210+
__syncthreads();
211+
212+
if (is_valid_thread && (threadIdx.x % nThreads_per_track) == 0) {
213+
temp_sorted_ids.at(shifted_idx) = tid;
214+
}
181215
}
182216

183217
} // namespace traccc::cuda::kernels

device/cuda/src/ambiguity_resolution/kernels/rearrange_tracks.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,7 @@
1212

1313
namespace traccc::cuda::kernels {
1414

15+
constexpr const int nThreads_per_track = 4;
16+
1517
__global__ void rearrange_tracks(device::rearrange_tracks_payload payload);
16-
}
18+
} // namespace traccc::cuda::kernels

0 commit comments

Comments
 (0)