Skip to content

Commit a887f7f

Browse files
authored
Merge pull request #1107 from beomki-yeo/optimize-block-inclusive-scan
Optimize the dimension of scanning kernels
2 parents 9b5e684 + ed00354 commit a887f7f

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,27 @@ greedy_ambiguity_resolution_algorithm::operator()(
404404
unsigned int nThreads_full = 1024;
405405
unsigned int nBlocks_full = (n_tracks + 1023) / 1024;
406406

407-
unsigned int nThreads_scan = 1024;
408-
unsigned int nBlocks_scan = (n_accepted + 1023) / 1024;
407+
// Compute the threadblock dimension for scanning kernels
408+
// Pick the launch dimensions (threads per block, number of blocks) for the
// scanning kernels.
//
// Starts from a block of four warps and doubles the block size while more
// than `max_dim` blocks would be needed to cover `n_accepted` elements.
// The block size is never grown past `max_dim` (1024), which is the CUDA
// per-block thread limit; the previous loop condition (`<= 1024`) allowed
// one extra doubling to 2048 threads, producing an invalid launch
// configuration while still passing the `nBlocks_scan <= 1024` check made
// by the caller.
//
// @param n_accepted Number of elements the scan has to cover
// @return `std::pair` of (threads per block, number of blocks). The number
//         of blocks can still exceed 1024 when `n_accepted` is larger than
//         1024 * 1024; the caller is expected to check for that.
auto compute_scan_config = [&](unsigned int n_accepted) {
    // Upper bound used for both the block size and the block count
    const unsigned int max_dim = 1024;

    unsigned int nThreads_scan = m_warp_size * 4;
    unsigned int nBlocks_scan =
        (n_accepted + nThreads_scan - 1) / nThreads_scan;

    // Grow the block only while it helps and while it stays launchable
    while (nBlocks_scan > max_dim && nThreads_scan < max_dim) {
        nThreads_scan *= 2;
        if (nThreads_scan > max_dim) {
            // Clamp in case the starting size does not double exactly
            // onto the limit
            nThreads_scan = max_dim;
        }
        nBlocks_scan = (n_accepted + nThreads_scan - 1) / nThreads_scan;
    }

    return std::make_pair(nThreads_scan, nBlocks_scan);
};
424+
425+
auto scan_dim = compute_scan_config(n_accepted);
426+
unsigned int nThreads_scan = scan_dim.first;
427+
unsigned int nBlocks_scan = scan_dim.second;
409428

410429
assert(nBlocks_scan <= 1024 &&
411430
"nBlocks_scan larger than 1024 will cause invalid arguments in "
@@ -423,7 +442,10 @@ greedy_ambiguity_resolution_algorithm::operator()(
423442
nBlocks_adaptive =
424443
(n_accepted + nThreads_adaptive - 1) / nThreads_adaptive;
425444
nBlocks_warp = (n_accepted + nThreads_warp - 1) / nThreads_warp;
426-
nBlocks_scan = (n_accepted + 1023) / 1024;
445+
446+
scan_dim = compute_scan_config(n_accepted);
447+
nThreads_scan = scan_dim.first;
448+
nBlocks_scan = scan_dim.second;
427449

428450
// Make CUDA Graph
429451
cudaGraph_t graph;

device/cuda/src/ambiguity_resolution/kernels/block_inclusive_scan.cu

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ __global__ void block_inclusive_scan(
3232

3333
auto globalIndex = threadIdx.x + blockIdx.x * blockDim.x;
3434
auto threadIndex = threadIdx.x;
35-
auto blockIndex = blockIdx.x;
36-
auto blockSize = blockDim.x;
3735

3836
const unsigned int n_accepted = *(payload.n_accepted);
3937

@@ -46,7 +44,7 @@ __global__ void block_inclusive_scan(
4644
__syncthreads();
4745

4846
// inclusive scan in shared memory
49-
for (int stride = 1; stride < blockSize; stride *= 2) {
47+
for (int stride = 1; stride < blockDim.x; stride *= 2) {
5048
int val = 0;
5149
if (threadIndex >= stride) {
5250
val = shared_temp[threadIndex - stride];
@@ -64,8 +62,8 @@ __global__ void block_inclusive_scan(
6462

6563
__syncthreads();
6664

67-
if (threadIndex == blockSize - 1) {
68-
block_offsets[blockIndex] = shared_temp[threadIndex];
65+
if (threadIndex == blockDim.x - 1) {
66+
block_offsets[blockIdx.x] = shared_temp[threadIndex];
6967
}
7068
}
7169

device/cuda/src/ambiguity_resolution/kernels/scan_block_offsets.cu

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,20 @@ __global__ void scan_block_offsets(device::scan_block_offsets_payload payload) {
2727
vecmem::device_vector<int> scanned_block_offsets(
2828
payload.scanned_block_offsets_view);
2929

30-
int n_blocks = (*(payload.n_accepted) + 1023) / 1024;
31-
32-
__syncthreads();
33-
30+
// The number of blocks in the previous block_inclusive_scan = the number of
31+
// threads of this kernel
32+
int n_blocks_prev = blockDim.x;
3433
auto threadIndex = threadIdx.x;
3534

3635
// 1. Load from global to shared
37-
int value = 0;
38-
if (threadIndex < n_blocks) {
39-
value = block_offsets[threadIndex];
36+
shared_temp[threadIndex] = 0;
37+
if (threadIndex < n_blocks_prev) {
38+
shared_temp[threadIndex] = block_offsets[threadIndex];
4039
}
41-
shared_temp[threadIndex] = value;
4240
__syncthreads();
4341

4442
// 2. Inclusive scan (Hillis-Steele style)
45-
for (int offset = 1; offset < n_blocks; offset *= 2) {
43+
for (int offset = 1; offset < n_blocks_prev; offset *= 2) {
4644
int temp = 0;
4745
if (threadIndex >= offset) {
4846
temp = shared_temp[threadIndex - offset];
@@ -53,7 +51,7 @@ __global__ void scan_block_offsets(device::scan_block_offsets_payload payload) {
5351
}
5452

5553
// 3. Write back
56-
if (threadIndex < n_blocks) {
54+
if (threadIndex < n_blocks_prev) {
5755
scanned_block_offsets[threadIndex] = shared_temp[threadIndex];
5856
}
5957
}

0 commit comments

Comments
 (0)