@@ -56,20 +56,24 @@ __launch_bounds__(1024) void __global__ kernel_reduce(
5656 const auto threadInWarp = threadIdx.x % warpSize;
5757 const auto warpsNeeded = (size + warpSize - 1 ) / warpSize;
5858
59+ auto value = operation.defaultValue ;
5960 auto acc = operation.defaultValue ;
6061
6162#pragma unroll 4
6263 for (std::size_t i = currentWarp; i < warpsNeeded; i += warpCount) {
6364 const auto id = threadInWarp + i * warpSize;
64- auto value = (id < size) ? static_cast <AccT>(ntload (&vector[id])) : operation.defaultValue ;
65+ const auto valueNew =
66+ (id < size) ? static_cast <AccT>(ntload (&vector[id])) : operation.defaultValue ;
6567
66- for (int offset = 1 ; offset < warpSize; offset *= 2 ) {
67- value = operation (value, shuffledown (value, offset));
68- }
68+ value = operation (value, valueNew);
69+ }
6970
70- acc = operation (acc, value);
71+ for (int offset = 1 ; offset < warpSize; offset *= 2 ) {
72+ value = operation (value, shuffledown (value, offset));
7173 }
7274
75+ acc = operation (acc, value);
76+
7377 if (threadInWarp == 0 ) {
7478 shmem[currentWarp] = acc;
7579 }
@@ -78,19 +82,24 @@ __launch_bounds__(1024) void __global__ kernel_reduce(
7882
7983 if (currentWarp == 0 ) {
8084 const auto lastWarpsNeeded = (warpCount + warpSize - 1 ) / warpSize;
85+
86+ auto value = operation.defaultValue ;
8187 auto lastAcc = operation.defaultValue ;
88+
8289#pragma unroll 2
8390 for (int i = 0 ; i < lastWarpsNeeded; ++i) {
8491 const auto id = threadInWarp + i * warpSize;
85- auto value = (id < warpCount) ? shmem[id] : operation.defaultValue ;
92+ const auto valueNew = (id < warpCount) ? shmem[id] : operation.defaultValue ;
8693
87- for (int offset = 1 ; offset < warpSize; offset *= 2 ) {
88- value = operation (value, shuffledown (value, offset));
89- }
94+ value = operation (value, valueNew);
95+ }
9096
91- lastAcc = operation (lastAcc, value);
97+ for (int offset = 1 ; offset < warpSize; offset *= 2 ) {
98+ value = operation (value, shuffledown (value, offset));
9299 }
93100
101+ lastAcc = operation (lastAcc, value);
102+
94103 if (threadIdx.x == 0 ) {
95104 if (overrideResult) {
96105 ntstore (result, lastAcc);
0 commit comments