Skip to content

Commit 9f4d4cb

Browse files
committed
Add manual reduce
1 parent 377769e commit 9f4d4cb

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

algorithms/sycl/Reduction.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,106 @@
77
#include "utils/logger.h"
88
#include <device.h>
99

10+
#include <limits>
1011
#include <sycl/sycl.hpp>
1112

13+
#if 1
14+
15+
namespace {
16+
using namespace device;
17+
18+
template<ReductionType Type, typename T>
19+
constexpr T neutral() {
20+
if constexpr(Type == ReductionType::Add) {
21+
return T(0);
22+
}
23+
if constexpr(Type == ReductionType::Max) {
24+
return std::numeric_limits<T>::min();
25+
}
26+
if constexpr(Type == ReductionType::Min) {
27+
return std::numeric_limits<T>::max();
28+
}
29+
return T(0);
30+
}
31+
32+
template <ReductionType Type, typename AccT, typename VecT, typename OpT> void launchReduction(AccT* result, const VecT *buffer, size_t size, OpT operation, bool overrideResult, void* streamPtr) {
33+
34+
constexpr auto DefaultValue = neutral<Type, AccT>();
35+
36+
((sycl::queue *) streamPtr)->submit([&](sycl::handler &cgh) {
37+
sycl::local_accessor<AccT, 1> shmem(256, cgh);
38+
cgh.parallel_for(sycl::nd_range<1> { 1024, 1024 },
39+
[=](sycl::nd_item<1> idx) {
40+
const auto subgroup = idx.get_sub_group();
41+
const auto sgSize = subgroup.get_local_range();
42+
43+
const auto warpCount = subgroup.get_group_range();
44+
const auto currentWarp = subgroup.get_group_id();
45+
const auto threadInWarp = subgroup.get_local_id();
46+
const auto warpsNeeded = (size + sgSize - 1) / sgSize;
47+
48+
auto acc = DefaultValue;
49+
50+
#pragma unroll 4
51+
for (std::size_t i = currentWarp; i < warpsNeeded; i += warpCount) {
52+
const auto id = threadInWarp + i * sgSize;
53+
auto value = (id < size) ? static_cast<AccT>(ntload(&buffer[id])) : DefaultValue;
54+
55+
value = sycl::reduce_over_group(subgroup, value, operation);
56+
57+
acc = operation(acc, value);
58+
}
59+
60+
if (threadInWarp == 0) {
61+
shmem[currentWarp] = acc;
62+
}
63+
64+
idx.barrier();
65+
66+
if (currentWarp == 0) {
67+
const auto lastWarpsNeeded = (warpCount + sgSize - 1) / sgSize;
68+
auto lastAcc = DefaultValue;
69+
#pragma unroll 2
70+
for (int i = 0; i < lastWarpsNeeded; ++i) {
71+
const auto id = threadInWarp + i * sgSize;
72+
auto value = (id < warpCount) ? shmem[id] : DefaultValue;
73+
74+
value = sycl::reduce_over_group(subgroup, value, operation);
75+
76+
lastAcc = operation(lastAcc, value);
77+
}
78+
79+
if (threadInWarp == 0) {
80+
if (overrideResult) {
81+
ntstore(result, lastAcc);
82+
}
83+
else {
84+
ntstore(result, operation(ntload(result), lastAcc));
85+
}
86+
}
87+
}
88+
});
89+
});
90+
}
91+
}
92+
93+
namespace device {
94+
template <typename AccT, typename VecT> void Algorithms::reduceVector(AccT* result, const VecT *buffer, bool overrideResult, size_t size, ReductionType type, void* streamPtr) {
95+
switch (type) {
96+
case ReductionType::Add: {
97+
return launchReduction<ReductionType::Add>(result, buffer, size, sycl::plus<AccT>(), overrideResult, streamPtr);
98+
}
99+
case ReductionType::Max: {
100+
return launchReduction<ReductionType::Max>(result, buffer, size, sycl::maximum<AccT>(), overrideResult, streamPtr);
101+
}
102+
case ReductionType::Min: {
103+
return launchReduction<ReductionType::Min>(result, buffer, size, sycl::minimum<AccT>(), overrideResult, streamPtr);
104+
}
105+
}
106+
}
107+
108+
#else
109+
12110
namespace {
13111
template <typename AccT, typename VecT, typename S> void launchReduction(AccT* result, const VecT *buffer, size_t size, S reducer, void* streamPtr) {
14112
((sycl::queue *) streamPtr)->submit([&](sycl::handler &cgh) {
@@ -43,6 +141,8 @@ template <typename AccT, typename VecT> void Algorithms::reduceVector(AccT* resu
43141
}
44142
}
45143

144+
#endif
145+
46146
template void Algorithms::reduceVector(int* result, const int *buffer, bool overrideResult, size_t size, ReductionType type, void* streamPtr);
47147
template void Algorithms::reduceVector(unsigned* result, const unsigned *buffer, bool overrideResult, size_t size, ReductionType type, void* streamPtr);
48148
template void Algorithms::reduceVector(long* result, const int *buffer, bool overrideResult, size_t size, ReductionType type, void* streamPtr);

0 commit comments

Comments
 (0)