Skip to content
112 changes: 70 additions & 42 deletions algorithms/sycl/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,61 +29,89 @@ namespace {
return T(0);
}

template <ReductionType Type, typename AccT, typename VecT, typename OpT> void launchReduction(AccT* result, const VecT *buffer, size_t size, OpT operation, bool overrideResult, void* streamPtr) {
template<ReductionType Type, typename AtomicRef, typename AccT>
void atomicUpdate(AtomicRef& atomic, AccT value){

constexpr auto DefaultValue = neutral<Type, AccT>();
constexpr auto MO = sycl::memory_order::relaxed;
AccT expected = neutral<Type, AccT>();

((sycl::queue *) streamPtr)->submit([&](sycl::handler &cgh) {
sycl::local_accessor<AccT, 1> shmem(256, cgh);
cgh.parallel_for(sycl::nd_range<1> { 1024, 1024 },
[=](sycl::nd_item<1> idx) {
const auto subgroup = idx.get_sub_group();
const auto sgSize = subgroup.get_local_range().size();
if constexpr(Type == ReductionType::Add) {
// Explicity pass MO to fetch_add
atomic.fetch_add(value, MO);
}
if constexpr(Type == ReductionType::Max) {
// sm 60 does not have a fetch max instruction.
// Using our own CAS loop
// Explicity pass MO to load
// AccT expected = atomic.load(MO);

while(expected < value){
if(atomic.compare_exchange_weak(expected, value, MO, MO)){
break;
}
}
// atomic.fetch_max(value);
}

const auto warpCount = subgroup.get_group_range().size();
const int currentWarp = subgroup.get_group_id();
const int threadInWarp = subgroup.get_local_id();
const auto warpsNeeded = (size + sgSize - 1) / sgSize;
if constexpr(Type == ReductionType::Min) {
//sm 60 does not have a fetch min instruction
// Using our own CAS loop
// AccT expected = atomic.load(MO);

while(expected > value){
if(atomic.compare_exchange_weak(expected, value, MO, MO)){
break;
}
}
// atomic.fetch_min(value);
}
}

auto acc = DefaultValue;

#pragma unroll 4
for (std::size_t i = currentWarp; i < warpsNeeded; i += warpCount) {
const auto id = threadInWarp + i * sgSize;
auto value = (id < size) ? static_cast<AccT>(ntload(&buffer[id])) : DefaultValue;
template <ReductionType Type, typename AccT, typename VecT, typename OpT> void launchReduction(AccT* result, const VecT *buffer, size_t size, OpT operation, bool overrideResult, void* streamPtr) {

value = sycl::reduce_over_group(subgroup, value, operation);
constexpr auto DefaultValue = neutral<Type, AccT>();
constexpr size_t workGroupSize = 256;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 256 and not 1024? (or are you at the bandwidth already like that?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and for higher numbers, there is no real improvement on PVC at least.

constexpr size_t itemsPerWorkItem = 8;

acc = operation(acc, value);
}
if(overrideResult){
((sycl::queue *) streamPtr)->submit([&](sycl::handler &cgh) {
cgh.single_task([=](){
// Initialize the global result to identity value.
*result = DefaultValue;
});
});
}

if (threadInWarp == 0) {
shmem[currentWarp] = acc;
}
((sycl::queue *) streamPtr)->submit([&](sycl::handler &cgh) {

idx.barrier();
const size_t numWorkGroups = (size + (workGroupSize * itemsPerWorkItem) - 1)
/ (workGroupSize * itemsPerWorkItem);

if (currentWarp == 0) {
const auto lastWarpsNeeded = (warpCount + sgSize - 1) / sgSize;
auto lastAcc = DefaultValue;
#pragma unroll 2
for (int i = 0; i < lastWarpsNeeded; ++i) {
const auto id = threadInWarp + i * sgSize;
auto value = (id < warpCount) ? shmem[id] : DefaultValue;
cgh.parallel_for(sycl::nd_range<1> { numWorkGroups*itemsPerWorkItem, workGroupSize },
[=](sycl::nd_item<1> idx) {

value = sycl::reduce_over_group(subgroup, value, operation);
const auto localId = idx.get_local_id(0);
const auto groupId = idx.get_group(0);

lastAcc = operation(lastAcc, value);
//Thread-local reduction
AccT threadAcc = DefaultValue;
size_t baseIdx = groupId*(workGroupSize*itemsPerWorkItem) + localId;

#pragma unroll
for (std::size_t i = 0; i < itemsPerWorkItem*workGroupSize; i += workGroupSize) {
const auto id = baseIdx + i;
if(id < size){
threadAcc = operation(threadAcc, static_cast<AccT>((buffer[id])));
}
}

if (threadInWarp == 0) {
if (overrideResult) {
ntstore(result, lastAcc);
}
else {
ntstore(result, operation(ntload(result), lastAcc));
}
}
const auto reducedValue = sycl::reduce_over_group(idx.get_group(), threadAcc, operation);

if(localId == 0){
sycl::atomic_ref<AccT, sycl::memory_order::relaxed,
sycl::memory_scope::device,
sycl::access::address_space::global_space> atomicRes(*result);
atomicUpdate<Type>(atomicRes, reducedValue);
}
});
});
Expand Down