@@ -600,30 +600,67 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
     return accessor{*MOutBufPtr, CGH};
   }
 
-  /// If reduction is initialized with read-write accessor, which does not
-  /// require initialization with identity value, then return user's read-write
-  /// accessor. Otherwise, create global buffer with 'num_elements' initialized
-  /// with identity value and return an accessor to that buffer.
-  template <bool HasFastAtomics = (has_fast_atomics || has_float64_atomics),
-            typename = std::enable_if_t<HasFastAtomics>>
-  auto getReadWriteAccessorToInitializedMem(handler &CGH) {
-    if constexpr (!is_usm) {
-      if (!base::initializeToIdentity())
-        return MRedOut;
-    }
+  /// Provide \p Func with properly initialized memory to write the reduction
+  /// result to. It can be either the user's original reduction variable or
+  /// newly allocated memory initialized with the reduction's identity. In the
+  /// latter case, after \p Func finishes, the user's original variable is
+  /// updated accordingly (i.e., honoring the initialize_to_identity property).
+  //
+  // This currently optimizes for the number of kernel instantiations rather
+  // than runtime latency. That might change in the future.
+  template <typename KernelName, typename FuncTy>
+  void withInitializedMem(handler &CGH, FuncTy Func) {
+    // "Template" lambda to ensure that only one instantiation of Func
+    // (USM/Buf) is created for the code below.
+    auto DoIt = [&](auto &Out) {
+      auto RWReduVal = std::make_shared<std::array<T, num_elements>>();
+      for (int i = 0; i < num_elements; ++i) {
+        (*RWReduVal)[i] = base::getIdentity();
+      }
+      CGH.addReduction(RWReduVal);
+      auto Buf = std::make_shared<buffer<T, 1>>(RWReduVal.get()->data(),
+                                                range<1>(num_elements));
+      Buf->set_final_data();
+      CGH.addReduction(Buf);
+      accessor Mem{*Buf, CGH};
+      Func(Mem);
 
-    // TODO: Move to T[] in C++20 to simplify handling here
-    // auto RWReduVal = std::make_shared<T[num_elements]>();
-    auto RWReduVal = std::make_shared<std::array<T, num_elements>>();
-    for (int i = 0; i < num_elements; ++i) {
-      (*RWReduVal)[i] = base::getIdentity();
+      reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
+        accessor Mem{*Buf, CopyHandler};
+        if constexpr (is_usm) {
+          // Can't capture the whole reduction object; copy the needed pieces
+          // into distinct variables.
+          bool IsUpdateOfUserVar = !base::initializeToIdentity();
+          auto BOp = base::getBinaryOperation();
+
+          // Don't use constexpr here: non-default host compilers (unlike
+          // clang) might actually create a capture, resulting in binary
+          // differences between host and device lambda captures.
+          size_t NElements = num_elements;
+
+          CopyHandler.single_task<KernelName>([=] {
+            for (int i = 0; i < NElements; ++i) {
+              if (IsUpdateOfUserVar)
+                Out[i] = BOp(Out[i], Mem[i]);
+              else
+                Out[i] = Mem[i];
+            }
+          });
+        } else {
+          associateWithHandler(CopyHandler, &Out, access::target::device);
+          CopyHandler.copy(Mem, Out);
+        }
+      });
+    };
+    if constexpr (is_usm) {
+      // Don't dispatch based on base::initializeToIdentity() as that would
+      // lead to two different instantiations of Func.
+      DoIt(MRedOut);
+    } else {
+      if (base::initializeToIdentity())
+        DoIt(MRedOut);
+      else
+        Func(MRedOut);
     }
-    CGH.addReduction(RWReduVal);
-    MOutBufPtr = std::make_shared<buffer<T, 1>>(RWReduVal.get()->data(),
-                                                range<1>(num_elements));
-    MOutBufPtr->set_final_data();
-    CGH.addReduction(MOutBufPtr);
-    return accessor{*MOutBufPtr, CGH};
   }
 
   accessor<int, 1, access::mode::read_write, access::target::device,
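
To make the contract concrete, here is a minimal standalone sketch of the semantics withInitializedMem provides, written against the public SYCL 2020 reduction API rather than taken from this patch (the queue setup, the sizes, and the int/plus combiner are illustrative assumptions). Without initialize_to_identity, the pre-existing value of the user's variable is folded into the result, which corresponds to the Func(MRedOut) and IsUpdateOfUserVar paths above; with the property, the variable is overwritten starting from the identity, which corresponds to the DoIt(MRedOut) path.

#include <sycl/sycl.hpp>

#include <iostream>

int main() {
  sycl::queue Q;
  int *Sum = sycl::malloc_shared<int>(1, Q);

  // Pre-existing value in the user's reduction variable.
  *Sum = 100;

  // Without initialize_to_identity, the old value is combined into the
  // result: the Func(MRedOut) / IsUpdateOfUserVar path above.
  Q.parallel_for(sycl::range<1>{8}, sycl::reduction(Sum, sycl::plus<int>()),
                 [=](sycl::id<1>, auto &R) { R += 1; })
      .wait();
  std::cout << *Sum << "\n"; // prints 108

  // With initialize_to_identity, the variable is overwritten starting from
  // the identity: the DoIt(MRedOut) path that seeds scratch memory.
  *Sum = 100;
  Q.parallel_for(sycl::range<1>{8},
                 sycl::reduction(
                     Sum, sycl::plus<int>(),
                     {sycl::property::reduction::initialize_to_identity{}}),
                 [=](sycl::id<1>, auto &R) { R += 1; })
      .wait();
  std::cout << *Sum << "\n"; // prints 8

  sycl::free(Sum, Q);
  return 0;
}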
@@ -855,45 +892,42 @@ struct NDRangeReduction<reduction::strategy::local_atomic_and_atomic_cross_wg> {
                   nd_range<Dims> NDRange, PropertiesT &Properties,
                   Reduction &Redu, KernelType &KernelFunc) {
     std::ignore = Queue;
-    size_t NElements = Reduction::num_elements;
-    auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
-    local_accessor<typename Reduction::result_type, 1> GroupSum{NElements, CGH};
+    Redu.template withInitializedMem<KernelName>(CGH, [&](auto Out) {
+      size_t NElements = Reduction::num_elements;
+      local_accessor<typename Reduction::result_type, 1> GroupSum{NElements,
+                                                                  CGH};
 
-    using Name = __sycl_reduction_kernel<
-        reduction::MainKrn, KernelName,
-        reduction::strategy::local_atomic_and_atomic_cross_wg>;
+      using Name = __sycl_reduction_kernel<
+          reduction::MainKrn, KernelName,
+          reduction::strategy::local_atomic_and_atomic_cross_wg>;
-    CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
-      // Call user's functions. Reducer.MValue gets initialized there.
-      typename Reduction::reducer_type Reducer;
-      KernelFunc(NDId, Reducer);
+      CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
+        // Call the user's function. Reducer.MValue gets initialized there.
+        typename Reduction::reducer_type Reducer;
+        KernelFunc(NDId, Reducer);
 
-      // Work-group cooperates to initialize multiple reduction variables
-      auto LID = NDId.get_local_id(0);
-      for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
-        GroupSum[E] = Reducer.getIdentity();
-      }
-      workGroupBarrier();
+        // The work-group cooperates to initialize multiple reduction variables
+        auto LID = NDId.get_local_id(0);
+        for (size_t E = LID; E < NElements; E += NDId.get_local_range(0)) {
+          GroupSum[E] = Reducer.getIdentity();
+        }
+        workGroupBarrier();
 
-      // Each work-item has its own reducer to combine
-      Reducer.template atomic_combine<access::address_space::local_space>(
-          &GroupSum[0]);
+        // Each work-item has its own reducer to combine
+        Reducer.template atomic_combine<access::address_space::local_space>(
+            &GroupSum[0]);
 
-      // Single work-item performs finalization for entire work-group
-      // TODO: Opportunity to parallelize across elements
-      workGroupBarrier();
-      if (LID == 0) {
-        for (size_t E = 0; E < NElements; ++E) {
-          Reducer.getElement(E) = GroupSum[E];
+        // Single work-item performs finalization for entire work-group
+        // TODO: Opportunity to parallelize across elements
+        workGroupBarrier();
+        if (LID == 0) {
+          for (size_t E = 0; E < NElements; ++E) {
+            Reducer.getElement(E) = GroupSum[E];
+          }
+          Reducer.template atomic_combine(&Out[0]);
         }
-        Reducer.template atomic_combine(&Out[0]);
-      }
-    });
-
-    if (Reduction::is_usm || Redu.initializeToIdentity())
-      reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
-        reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
       });
+    });
   }
 };
 
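The kernel above is a two-level atomic pattern: every work-item combines into a work-group-local value with local-space atomics, and one leader per group folds that partial result into global memory. Below is a minimal standalone sketch of the same pattern using the public SYCL 2020 API; the plain int sum and the N/WG sizes are assumptions for illustration, not code from this patch.

#include <sycl/sycl.hpp>

#include <iostream>

int main() {
  sycl::queue Q;
  constexpr size_t N = 1024, WG = 128;
  int *Result = sycl::malloc_shared<int>(1, Q);
  *Result = 0;

  Q.submit([&](sycl::handler &CGH) {
     sycl::local_accessor<int, 1> GroupSum{1, CGH};
     CGH.parallel_for(sycl::nd_range<1>{N, WG}, [=](sycl::nd_item<1> It) {
       // The group cooperatively initializes its partial sum.
       if (It.get_local_id(0) == 0)
         GroupSum[0] = 0;
       sycl::group_barrier(It.get_group());

       // Level 1: every work-item combines into local memory.
       sycl::atomic_ref<int, sycl::memory_order::relaxed,
                        sycl::memory_scope::work_group,
                        sycl::access::address_space::local_space>
           Local{GroupSum[0]};
       Local += 1; // this work-item's contribution

       // Level 2: one leader per group combines into the global result.
       sycl::group_barrier(It.get_group());
       if (It.get_local_id(0) == 0) {
         sycl::atomic_ref<int, sycl::memory_order::relaxed,
                          sycl::memory_scope::device,
                          sycl::access::address_space::global_space>
             Global{*Result};
         Global += GroupSum[0];
       }
     });
   }).wait();

  std::cout << *Result << "\n"; // prints 1024
  sycl::free(Result, Q);
  return 0;
}

The payoff is one global atomic per work-group instead of one per work-item, which is presumably why these strategies are guarded by the fast-atomics checks seen earlier.
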
@@ -1125,32 +1159,27 @@ struct NDRangeReduction<reduction::strategy::group_reduce_and_atomic_cross_wg> {
                   nd_range<Dims> NDRange, PropertiesT &Properties,
                   Reduction &Redu, KernelType &KernelFunc) {
     std::ignore = Queue;
-    auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
-    size_t NElements = Reduction::num_elements;
-
-    using Name = __sycl_reduction_kernel<
-        reduction::MainKrn, KernelName,
-        reduction::strategy::group_reduce_and_atomic_cross_wg>;
+    Redu.template withInitializedMem<KernelName>(CGH, [&](auto Out) {
+      size_t NElements = Reduction::num_elements;
 
-    CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
-      // Call user's function. Reducer.MValue gets initialized there.
-      typename Reduction::reducer_type Reducer;
-      KernelFunc(NDIt, Reducer);
+      using Name = __sycl_reduction_kernel<
+          reduction::MainKrn, KernelName,
+          reduction::strategy::group_reduce_and_atomic_cross_wg>;
 
-      typename Reduction::binary_operation BOp;
-      for (int E = 0; E < NElements; ++E) {
-        Reducer.getElement(E) =
-            reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
-      }
-      if (NDIt.get_local_linear_id() == 0)
-        Reducer.atomic_combine(&Out[0]);
-    });
+      CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
+        // Call user's function. Reducer.MValue gets initialized there.
+        typename Reduction::reducer_type Reducer;
+        KernelFunc(NDIt, Reducer);
 
-    if (Reduction::is_usm || Redu.initializeToIdentity()) {
-      reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
-        reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
+        typename Reduction::binary_operation BOp;
+        for (int E = 0; E < NElements; ++E) {
+          Reducer.getElement(E) =
+              reduce_over_group(NDIt.get_group(), Reducer.getElement(E), BOp);
+        }
+        if (NDIt.get_local_linear_id() == 0)
+          Reducer.atomic_combine(&Out[0]);
       });
-    }
+    });
   }
 };
 
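Here the explicit local-memory stage is replaced by the group collective: reduce_over_group computes each work-group's partial result, and only the group leader issues the cross-work-group atomic. A minimal sketch of the same pattern under the same illustrative assumptions (an int sum with assumed N and WG, not code from this patch):

#include <sycl/sycl.hpp>

#include <iostream>

int main() {
  sycl::queue Q;
  constexpr size_t N = 1024, WG = 128;
  int *Result = sycl::malloc_shared<int>(1, Q);
  *Result = 0;

  Q.parallel_for(sycl::nd_range<1>{N, WG}, [=](sycl::nd_item<1> It) {
     // The collective produces the group's partial result in registers.
     int Partial =
         sycl::reduce_over_group(It.get_group(), 1, sycl::plus<int>());
     // Only the group leader touches global memory.
     if (It.get_local_linear_id() == 0) {
       sycl::atomic_ref<int, sycl::memory_order::relaxed,
                        sycl::memory_scope::device,
                        sycl::access::address_space::global_space>
           Global{*Result};
       Global += Partial;
     }
   }).wait();

  std::cout << *Result << "\n"; // prints 1024
  sycl::free(Result, Q);
  return 0;
}
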
@@ -1163,76 +1192,73 @@ struct NDRangeReduction<
                   nd_range<Dims> NDRange, PropertiesT &Properties,
                   Reduction &Redu, KernelType &KernelFunc) {
     std::ignore = Queue;
-    auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
-    size_t NElements = Reduction::num_elements;
-    size_t WGSize = NDRange.get_local_range().size();
-    bool IsPow2WG = (WGSize & (WGSize - 1)) == 0;
+    Redu.template withInitializedMem<KernelName>(CGH, [&](auto Out) {
+      size_t NElements = Reduction::num_elements;
+      size_t WGSize = NDRange.get_local_range().size();
+      bool IsPow2WG = (WGSize & (WGSize - 1)) == 0;
+
+      // Use local memory to reduce elements in work-groups into the zero-th
+      // element. If WGSize is not a power of two, then WGSize+1 elements are
+      // allocated; the additional last element catches elements that could
+      // otherwise be lost in the tree-reduction algorithm used in the kernel.
+      size_t NLocalElements = WGSize + (IsPow2WG ? 0 : 1);
+      local_accessor<typename Reduction::result_type, 1> LocalReds{
+          NLocalElements, CGH};
 
-    // Use local memory to reduce elements in work-groups into zero-th element.
-    // If WGSize is not power of two, then WGSize+1 elements are allocated.
-    // The additional last element is used to catch reduce elements that could
-    // otherwise be lost in the tree-reduction algorithm used in the kernel.
-    size_t NLocalElements = WGSize + (IsPow2WG ? 0 : 1);
-    local_accessor<typename Reduction::result_type, 1> LocalReds{NLocalElements,
-                                                                 CGH};
+      using Name = __sycl_reduction_kernel<
+          reduction::MainKrn, KernelName,
+          reduction::strategy::local_mem_tree_and_atomic_cross_wg>;
 
-    using Name = __sycl_reduction_kernel<
-        reduction::MainKrn, KernelName,
-        reduction::strategy::local_mem_tree_and_atomic_cross_wg>;
+      CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
+        // Call the user's function. Reducer.MValue gets initialized there.
+        typename Reduction::reducer_type Reducer;
+        KernelFunc(NDIt, Reducer);
 
-    CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
-      // Call user's functions. Reducer.MValue gets initialized there.
-      typename Reduction::reducer_type Reducer;
-      KernelFunc(NDIt, Reducer);
+        size_t WGSize = NDIt.get_local_range().size();
+        size_t LID = NDIt.get_local_linear_id();
 
-      size_t WGSize = NDIt.get_local_range().size();
-      size_t LID = NDIt.get_local_linear_id();
+        // If there are multiple values, reduce each separately
+        // This prevents local memory from scaling with elements
+        for (int E = 0; E < NElements; ++E) {
 
-      // If there are multiple values, reduce each separately
-      // This prevents local memory from scaling with elements
-      for (int E = 0; E < NElements; ++E) {
+          // Copy the element to local memory to prepare it for tree-reduction.
+          LocalReds[LID] = Reducer.getElement(E);
+          if (!IsPow2WG)
+            LocalReds[WGSize] = Reducer.getIdentity();
+          NDIt.barrier();
 
-        // Copy the element to local memory to prepare it for tree-reduction.
-        LocalReds[LID] = Reducer.getElement(E);
-        if (!IsPow2WG)
-          LocalReds[WGSize] = Reducer.getIdentity();
-        NDIt.barrier();
+          // Tree-reduction: reduce the local array LocalReds[:] to
+          // LocalReds[0]. LocalReds[WGSize] accumulates last/odd elements when
+          // the step of tree-reduction loop is not even.
+          typename Reduction::binary_operation BOp;
+          size_t PrevStep = WGSize;
+          for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
+            if (LID < CurStep)
+              LocalReds[LID] = BOp(LocalReds[LID], LocalReds[LID + CurStep]);
+            else if (!IsPow2WG && LID == CurStep && (PrevStep & 0x1))
+              LocalReds[WGSize] =
+                  BOp(LocalReds[WGSize], LocalReds[PrevStep - 1]);
+            NDIt.barrier();
+            PrevStep = CurStep;
+          }
 
-        // Tree-reduction: reduce the local array LocalReds[:] to LocalReds[0].
-        // LocalReds[WGSize] accumulates last/odd elements when the step
-        // of tree-reduction loop is not even.
-        typename Reduction::binary_operation BOp;
-        size_t PrevStep = WGSize;
-        for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
-          if (LID < CurStep)
-            LocalReds[LID] = BOp(LocalReds[LID], LocalReds[LID + CurStep]);
-          else if (!IsPow2WG && LID == CurStep && (PrevStep & 0x1))
-            LocalReds[WGSize] = BOp(LocalReds[WGSize], LocalReds[PrevStep - 1]);
-          NDIt.barrier();
-          PrevStep = CurStep;
-        }
+          if (LID == 0) {
+            Reducer.getElement(E) =
+                IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
+          }
 
-        if (LID == 0) {
-          Reducer.getElement(E) =
-              IsPow2WG ? LocalReds[0] : BOp(LocalReds[0], LocalReds[WGSize]);
+          // Ensure item 0 is finished with LocalReds before next iteration
+          if (E != NElements - 1) {
+            NDIt.barrier();
+          }
         }
 
-        // Ensure item 0 is finished with LocalReds before next iteration
-        if (E != NElements - 1) {
-          NDIt.barrier();
+        if (LID == 0) {
+          Reducer.atomic_combine(&Out[0]);
         }
-      }
-
-      if (LID == 0) {
-        Reducer.atomic_combine(&Out[0]);
-      }
-    });
-
-    if (Reduction::is_usm || Redu.initializeToIdentity()) {
-      reduction::withAuxHandler(CGH, [&](handler &CopyHandler) {
-        reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
       });
-    }
+    });
   }
 };
 
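The subtle part of this kernel is the non-power-of-two work-group size: whenever a halving step has odd width, the element at index PrevStep - 1 would fall out of the tree, so it is folded into the spare LocalReds[WGSize] slot instead, and the leader combines LocalReds[0] with that slot at the end. A standalone sketch of the same loop with a deliberately non-power-of-two group size (again an assumed int sum, not code from this patch):

#include <sycl/sycl.hpp>

#include <iostream>

int main() {
  sycl::queue Q;
  // WG is deliberately not a power of two to exercise the overflow slot.
  constexpr size_t N = 960, WG = 96;
  int *Result = sycl::malloc_shared<int>(1, Q);
  *Result = 0;

  Q.submit([&](sycl::handler &CGH) {
     // WG + 1 elements: the extra slot catches odd elements, as above.
     sycl::local_accessor<int, 1> LocalReds{WG + 1, CGH};
     CGH.parallel_for(sycl::nd_range<1>{N, WG}, [=](sycl::nd_item<1> It) {
       size_t LID = It.get_local_linear_id();
       LocalReds[LID] = 1; // this work-item's contribution
       LocalReds[WG] = 0;  // identity; every item stores it, as in the kernel
       sycl::group_barrier(It.get_group());

       // Tree-reduction into LocalReds[0]; odd leftovers go to LocalReds[WG].
       size_t PrevStep = WG;
       for (size_t CurStep = PrevStep >> 1; CurStep > 0; CurStep >>= 1) {
         if (LID < CurStep)
           LocalReds[LID] += LocalReds[LID + CurStep];
         else if (LID == CurStep && (PrevStep & 0x1))
           LocalReds[WG] += LocalReds[PrevStep - 1];
         sycl::group_barrier(It.get_group());
         PrevStep = CurStep;
       }

       if (LID == 0) {
         sycl::atomic_ref<int, sycl::memory_order::relaxed,
                          sycl::memory_scope::device,
                          sycl::access::address_space::global_space>
             Global{*Result};
         Global += LocalReds[0] + LocalReds[WG];
       }
     });
   }).wait();

  std::cout << *Result << "\n"; // prints 960
  sycl::free(Result, Q);
  return 0;
}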