Skip to content

Commit 31c1199

Browse files
committed
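Update CUDA/HIP reduction a tiny bit: combine the values each thread loads inside the loop and run the warp shuffle reduction once after the loop instead of once per iteration (applied to both the per-warp stage and the final shared-memory stage)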
1 parent 813a7b0 commit 31c1199

File tree

1 file changed

19 additions and 10 deletions

1 file changed

19 additions and 10 deletions

algorithms/cudahip/Reduction.cpp

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -56,20 +56,24 @@ __launch_bounds__(1024) void __global__ kernel_reduce(
5656
const auto threadInWarp = threadIdx.x % warpSize;
5757
const auto warpsNeeded = (size + warpSize - 1) / warpSize;
5858

59+
auto value = operation.defaultValue;
5960
auto acc = operation.defaultValue;
6061

6162
#pragma unroll 4
6263
for (std::size_t i = currentWarp; i < warpsNeeded; i += warpCount) {
6364
const auto id = threadInWarp + i * warpSize;
64-
auto value = (id < size) ? static_cast<AccT>(ntload(&vector[id])) : operation.defaultValue;
65+
const auto valueNew =
66+
(id < size) ? static_cast<AccT>(ntload(&vector[id])) : operation.defaultValue;
6567

66-
for (int offset = 1; offset < warpSize; offset *= 2) {
67-
value = operation(value, shuffledown(value, offset));
68-
}
68+
value = operation(value, valueNew);
69+
}
6970

70-
acc = operation(acc, value);
71+
for (int offset = 1; offset < warpSize; offset *= 2) {
72+
value = operation(value, shuffledown(value, offset));
7173
}
7274

75+
acc = operation(acc, value);
76+
7377
if (threadInWarp == 0) {
7478
shmem[currentWarp] = acc;
7579
}
@@ -78,19 +82,24 @@ __launch_bounds__(1024) void __global__ kernel_reduce(
7882

7983
if (currentWarp == 0) {
8084
const auto lastWarpsNeeded = (warpCount + warpSize - 1) / warpSize;
85+
86+
auto value = operation.defaultValue;
8187
auto lastAcc = operation.defaultValue;
88+
8289
#pragma unroll 2
8390
for (int i = 0; i < lastWarpsNeeded; ++i) {
8491
const auto id = threadInWarp + i * warpSize;
85-
auto value = (id < warpCount) ? shmem[id] : operation.defaultValue;
92+
const auto valueNew = (id < warpCount) ? shmem[id] : operation.defaultValue;
8693

87-
for (int offset = 1; offset < warpSize; offset *= 2) {
88-
value = operation(value, shuffledown(value, offset));
89-
}
94+
value = operation(value, valueNew);
95+
}
9096

91-
lastAcc = operation(lastAcc, value);
97+
for (int offset = 1; offset < warpSize; offset *= 2) {
98+
value = operation(value, shuffledown(value, offset));
9299
}
93100

101+
lastAcc = operation(lastAcc, value);
102+
94103
if (threadIdx.x == 0) {
95104
if (overrideResult) {
96105
ntstore(result, lastAcc);

0 commit comments

Comments
 (0)