Skip to content

Commit 4bebb35

Browse files
[NFC][SYCL][Reduction] Rethink getReadWriteAccessorToInitializedMem helper (#7358)
Previously the helper abstraction consisted of two implicitly tied pieces that had to be used together and used member variable MOutBufPtr to pass some data behind the scene. Change that to a new abstraction withInitializedMem accepting a callable argument to make the data flow explicit and more immediately clear. This change also allows to change/optimize this piece in one place without the need to modify multiple uses.
1 parent f90d2b4 commit 4bebb35

File tree

1 file changed

+160
-134
lines changed

1 file changed

+160
-134
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 160 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -600,30 +600,67 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
600600
return accessor{*MOutBufPtr, CGH};
601601
}
602602

603-
/// If reduction is initialized with read-write accessor, which does not
604-
/// require initialization with identity value, then return user's read-write
605-
/// accessor. Otherwise, create global buffer with 'num_elements' initialized
606-
/// with identity value and return an accessor to that buffer.
607-
template <bool HasFastAtomics = (has_fast_atomics || has_float64_atomics),
608-
typename = std::enable_if_t<HasFastAtomics>>
609-
auto getReadWriteAccessorToInitializedMem(handler &CGH) {
610-
if constexpr (!is_usm) {
611-
if (!base::initializeToIdentity())
612-
return MRedOut;
613-
}
603+
/// Provide \p Func with a properly initialized memory to write the reduction
604+
/// result to. It can either be original user's reduction variable or a newly
605+
/// allocated memory initialized with reduction's identity. In the later case,
606+
/// after the \p Func finishes, update original user's variable accordingly
607+
/// (i.e., honoring initialize_to_identity property).
608+
//
609+
// This currently optimizes for a number of kernel instantiations instead of
610+
// runtime latency. That might change in future.
611+
template <typename KernelName, typename FuncTy>
612+
void withInitializedMem(handler &CGH, FuncTy Func) {
613+
// "Template" lambda to ensure that only one type of Func (USM/Buf) is
614+
// instantiated for the code below.
615+
auto DoIt = [&](auto &Out) {
616+
auto RWReduVal = std::make_shared<std::array<T, num_elements>>();
617+
for (int i = 0; i < num_elements; ++i) {
618+
(*RWReduVal)[i] = base::getIdentity();
619+
}
620+
CGH.addReduction(RWReduVal);
621+
auto Buf = std::make_shared<buffer<T, 1>>(RWReduVal.get()->data(),
622+
range<1>(num_elements));
623+
Buf->set_final_data();
624+
CGH.addReduction(Buf);
625+
accessor Mem{*Buf, CGH};
626+
Func(Mem);
614627

615-
// TODO: Move to T[] in C++20 to simplify handling here
616-
// auto RWReduVal = std::make_shared<T[num_elements]>();
617-
auto RWReduVal = std::make_shared<std::array<T, num_elements>>();
618-
for (int i = 0; i < num_elements; ++i) {
619-
(*RWReduVal)[i] = base::getIdentity();
628+
reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
629+
accessor Mem{*Buf, CopyHandler};
630+
if constexpr (is_usm) {
631+
// Can't capture whole reduction, copy into distinct variables.
632+
bool IsUpdateOfUserVar = !base::initializeToIdentity();
633+
auto BOp = base::getBinaryOperation();
634+
635+
// Don't use constexpr as non-default host compilers (unlike clang)
636+
// might actually create a capture resulting in binary differences
637+
// between host/device in lambda captures.
638+
size_t NElements = num_elements;
639+
640+
CopyHandler.single_task<KernelName>([=] {
641+
for (int i = 0; i < NElements; ++i) {
642+
if (IsUpdateOfUserVar)
643+
Out[i] = BOp(Out[i], Mem[i]);
644+
else
645+
Out[i] = Mem[i];
646+
}
647+
});
648+
} else {
649+
associateWithHandler(CopyHandler, &Out, access::target::device);
650+
CopyHandler.copy(Mem, Out);
651+
}
652+
});
653+
};
654+
if constexpr (is_usm) {
655+
// Don't dispatch based on base::initializeToIdentity() as that would lead
656+
// to two different instantiations of Func.
657+
DoIt(MRedOut);
658+
} else {
659+
if (base::initializeToIdentity())
660+
DoIt(MRedOut);
661+
else
662+
Func(MRedOut);
620663
}
621-
CGH.addReduction(RWReduVal);
622-
MOutBufPtr = std::make_shared<buffer<T, 1>>(RWReduVal.get()->data(),
623-
range<1>(num_elements));
624-
MOutBufPtr->set_final_data();
625-
CGH.addReduction(MOutBufPtr);
626-
return accessor{*MOutBufPtr, CGH};
627664
}
628665

629666
accessor<int, 1, access::mode::read_write, access::target::device,
@@ -855,45 +892,42 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
855892
nd_range<Dims> NDRange, PropertiesT &Properties,
856893
Reduction &Redu, KernelType &KernelFunc) {
857894
std::ignore = Queue;
858-
size_t NElements = Reduction::num_elements;
859-
auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
860-
local_accessor<typename Reduction::result_type, 1> GroupSum{NElements, CGH};
895+
Redu.template withInitializedMem<KernelName>(CGH, [&](auto Out) {
896+
size_t NElements = Reduction::num_elements;
897+
local_accessor<typename Reduction::result_type, 1> GroupSum{NElements,
898+
CGH};
861899

862-
using Name = __sycl_reduction_kernel<
863-
reduction::MainKrn, KernelName,
864-
reduction::strategy::local_atomic_and_atomic_cross_wg>;
900+
using Name = __sycl_reduction_kernel<
901+
reduction::MainKrn, KernelName,
902+
reduction::strategy::local_atomic_and_atomic_cross_wg>;
865903

866-
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
867-
// Call user's functions. Reducer.MValue gets initialized there.
868-
typename Reduction::reducer_type Reducer;
869-
KernelFunc(NDId, Reducer);
904+
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
905+
// Call user's functions. Reducer.MValue gets initialized there.
906+
typename Reduction::reducer_type Reducer;
907+
KernelFunc(NDId, Reducer);
870908

871-
// Work-group cooperates to initialize multiple reduction variables
872-
auto LID = NDId.get_local_id(0);
873-
for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
874-
GroupSum[E] = Reducer.getIdentity();
875-
}
876-
workGroupBarrier();
909+
// Work-group cooperates to initialize multiple reduction variables
910+
auto LID = NDId.get_local_id(0);
911+
for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
912+
GroupSum[E] = Reducer.getIdentity();
913+
}
914+
workGroupBarrier();
877915

878-
// Each work-item has its own reducer to combine
879-
Reducer.template atomic_combine<access::address_space::local_space>(
880-
&GroupSum[0]);
916+
// Each work-item has its own reducer to combine
917+
Reducer.template atomic_combine<access::address_space::local_space>(
918+
&GroupSum[0]);
881919

882-
// Single work-item performs finalization for entire work-group
883-
// TODO: Opportunity to parallelize across elements
884-
workGroupBarrier();
885-
if (LID == 0) {
886-
for (size_t E = 0; E < NElements; ++E) {
887-
Reducer.getElement(E) = GroupSum[E];
920+
// Single work-item performs finalization for entire work-group
921+
// TODO: Opportunity to parallelize across elements
922+
workGroupBarrier();
923+
if (LID == 0) {
924+
for (size_t E = 0; E < NElements; ++E) {
925+
Reducer.getElement(E) = GroupSum[E];
926+
}
927+
Reducer.template atomic_combine(&Out[0]);
888928
}
889-
Reducer.template atomic_combine(&Out[0]);
890-
}
891-
});
892-
893-
if (Reduction::is_usm || Redu.initializeToIdentity())
894-
reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
895-
reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
896929
});
930+
});
897931
}
898932
};
899933

@@ -1125,32 +1159,27 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {
11251159
nd_range<Dims> NDRange, PropertiesT &Properties,
11261160
Reduction &Redu, KernelType &KernelFunc) {
11271161
std::ignore = Queue;
1128-
auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
1129-
size_t NElements = Reduction::num_elements;
1130-
1131-
using Name = __sycl_reduction_kernel<
1132-
reduction::MainKrn, KernelName,
1133-
reduction::strategy::group_reduce_and_atomic_cross_wg>;
1162+
Redu.template withInitializedMem<KernelName>(CGH, [&](auto Out) {
1163+
size_t NElements = Reduction::num_elements;
11341164

1135-
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
1136-
// Call user's function. Reducer.MValue gets initialized there.
1137-
typename Reduction::reducer_type Reducer;
1138-
KernelFunc(NDIt, Reducer);
1165+
using Name = __sycl_reduction_kernel<
1166+
reduction::MainKrn, KernelName,
1167+
reduction::strategy::group_reduce_and_atomic_cross_wg>;
11391168

1140-
typename Reduction::binary_operation BOp;
1141-
for (int E = 0; E < NElements; ++E) {
1142-
Reducer.getElement(E) =
1143-
reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
1144-
}
1145-
if (NDIt.get_local_linear_id() == 0)
1146-
Reducer.atomic_combine(&Out[0]);
1147-
});
1169+
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
1170+
// Call user's function. Reducer.MValue gets initialized there.
1171+
typename Reduction::reducer_type Reducer;
1172+
KernelFunc(NDIt, Reducer);
11481173

1149-
if (Reduction::is_usm || Redu.initializeToIdentity()) {
1150-
reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
1151-
reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
1174+
typename Reduction::binary_operation BOp;
1175+
for (int E = 0; E < NElements; ++E) {
1176+
Reducer.getElement(E) =
1177+
reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
1178+
}
1179+
if (NDIt.get_local_linear_id() == 0)
1180+
Reducer.atomic_combine(&Out[0]);
11521181
});
1153-
}
1182+
});
11541183
}
11551184
};
11561185

@@ -1163,76 +1192,73 @@ struct NDRangeReduction<
11631192
nd_range<Dims> NDRange, PropertiesT &Properties,
11641193
Reduction &Redu, KernelType &KernelFunc) {
11651194
std::ignore = Queue;
1166-
auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
1167-
size_t NElements = Reduction::num_elements;
1168-
size_t WGSize = NDRange.get_local_range().size();
1169-
bool IsPow2WG = (WGSize & (WGSize - 1)) == 0;
1195+
Redu.template withInitializedMem<KernelName>(CGH, [&](auto Out) {
1196+
size_t NElements = Reduction::num_elements;
1197+
size_t WGSize = NDRange.get_local_range().size();
1198+
bool IsPow2WG = (WGSize & (WGSize - 1)) == 0;
1199+
1200+
// Use local memory to reduce elements in work-groups into zero-th
1201+
// element. If WGSize is not power of two, then WGSize+1 elements are
1202+
// allocated. The additional last element is used to catch reduce elements
1203+
// that could otherwise be lost in the tree-reduction algorithm used in
1204+
// the kernel.
1205+
size_t NLocalElements = WGSize + (IsPow2WG ? 0 : 1);
1206+
local_accessor<typename Reduction::result_type, 1> LocalReds{
1207+
NLocalElements, CGH};
11701208

1171-
// Use local memory to reduce elements in work-groups into zero-th element.
1172-
// If WGSize is not power of two, then WGSize+1 elements are allocated.
1173-
// The additional last element is used to catch reduce elements that could
1174-
// otherwise be lost in the tree-reduction algorithm used in the kernel.
1175-
size_t NLocalElements = WGSize + (IsPow2WG ? 0 : 1);
1176-
local_accessor<typename Reduction::result_type, 1> LocalReds{NLocalElements,
1177-
CGH};
1209+
using Name = __sycl_reduction_kernel<
1210+
reduction::MainKrn, KernelName,
1211+
reduction::strategy::local_mem_tree_and_atomic_cross_wg>;
11781212

1179-
using Name = __sycl_reduction_kernel<
1180-
reduction::MainKrn, KernelName,
1181-
reduction::strategy::local_mem_tree_and_atomic_cross_wg>;
1213+
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
1214+
// Call user's functions. Reducer.MValue gets initialized there.
1215+
typename Reduction::reducer_type Reducer;
1216+
KernelFunc(NDIt, Reducer);
11821217

1183-
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
1184-
// Call user's functions. Reducer.MValue gets initialized there.
1185-
typename Reduction::reducer_type Reducer;
1186-
KernelFunc(NDIt, Reducer);
1218+
size_t WGSize = NDIt.get_local_range().size();
1219+
size_t LID = NDIt.get_local_linear_id();
11871220

1188-
size_t WGSize = NDIt.get_local_range().size();
1189-
size_t LID = NDIt.get_local_linear_id();
1221+
// If there are multiple values, reduce each separately
1222+
// This prevents local memory from scaling with elements
1223+
for (int E = 0; E < NElements; ++E) {
11901224

1191-
// If there are multiple values, reduce each separately
1192-
// This prevents local memory from scaling with elements
1193-
for (int E = 0; E < NElements; ++E) {
1225+
// Copy the element to local memory to prepare it for tree-reduction.
1226+
LocalReds[LID] = Reducer.getElement(E);
1227+
if (!IsPow2WG)
1228+
LocalReds[WGSize] = Reducer.getIdentity();
1229+
NDIt.barrier();
11941230

1195-
// Copy the element to local memory to prepare it for tree-reduction.
1196-
LocalReds[LID] = Reducer.getElement(E);
1197-
if (!IsPow2WG)
1198-
LocalReds[WGSize] = Reducer.getIdentity();
1199-
NDIt.barrier();
1231+
// Tree-reduction: reduce the local array LocalReds[:] to
1232+
// LocalReds[0]. LocalReds[WGSize] accumulates last/odd elements when
1233+
// the step of tree-reduction loop is not even.
1234+
typename Reduction::binary_operation BOp;
1235+
size_t PrevStep = WGSize;
1236+
for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
1237+
if (LID < CurStep)
1238+
LocalReds[LID] = BOp(LocalReds[LID], LocalReds[LID + CurStep]);
1239+
else if (!IsPow2WG && LID == CurStep && (PrevStep & 0x1))
1240+
LocalReds[WGSize] =
1241+
BOp(LocalReds[WGSize], LocalReds[PrevStep - 1]);
1242+
NDIt.barrier();
1243+
PrevStep = CurStep;
1244+
}
12001245

1201-
// Tree-reduction: reduce the local array LocalReds[:] to LocalReds[0].
1202-
// LocalReds[WGSize] accumulates last/odd elements when the step
1203-
// of tree-reduction loop is not even.
1204-
typename Reduction::binary_operation BOp;
1205-
size_t PrevStep = WGSize;
1206-
for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
1207-
if (LID < CurStep)
1208-
LocalReds[LID] = BOp(LocalReds[LID], LocalReds[LID + CurStep]);
1209-
else if (!IsPow2WG && LID == CurStep && (PrevStep & 0x1))
1210-
LocalReds[WGSize] = BOp(LocalReds[WGSize], LocalReds[PrevStep - 1]);
1211-
NDIt.barrier();
1212-
PrevStep = CurStep;
1213-
}
1246+
if (LID == 0) {
1247+
Reducer.getElement(E) =
1248+
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
1249+
}
12141250

1215-
if (LID == 0) {
1216-
Reducer.getElement(E) =
1217-
IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
1251+
// Ensure item 0 is finished with LocalReds before next iteration
1252+
if (E != NElements - 1) {
1253+
NDIt.barrier();
1254+
}
12181255
}
12191256

1220-
// Ensure item 0 is finished with LocalReds before next iteration
1221-
if (E != NElements - 1) {
1222-
NDIt.barrier();
1257+
if (LID == 0) {
1258+
Reducer.atomic_combine(&Out[0]);
12231259
}
1224-
}
1225-
1226-
if (LID == 0) {
1227-
Reducer.atomic_combine(&Out[0]);
1228-
}
1229-
});
1230-
1231-
if (Reduction::is_usm || Redu.initializeToIdentity()) {
1232-
reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
1233-
reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
12341260
});
1235-
}
1261+
});
12361262
}
12371263
};
12381264

0 commit comments

Comments
 (0)