@@ -144,8 +144,8 @@ __SYCL_EXPORT size_t reduGetMaxWGSize(std::shared_ptr<queue_impl> Queue,
144144 size_t LocalMemBytesPerWorkItem);
145145__SYCL_EXPORT size_t reduComputeWGSize (size_t NWorkItems, size_t MaxWGSize,
146146 size_t &NWorkGroups);
147- __SYCL_EXPORT size_t reduGetPreferredWGSize (std::shared_ptr<queue_impl> &Queue,
148- size_t LocalMemBytesPerWorkItem);
147+ __SYCL_EXPORT size_t reduGetPreferredDeviceWGSize (
148+ std::shared_ptr<queue_impl> &Queue, size_t LocalMemBytesPerWorkItem);
149149
150150template <typename T, class BinaryOperation , bool IsOptional>
151151class ReducerElement ;
@@ -1200,6 +1200,25 @@ void reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
12001200 });
12011201}
12021202
1203+ template <typename KernelName>
1204+ size_t reduGetPreferredKernelWGSize (std::shared_ptr<queue_impl> &Queue) {
1205+ using namespace info ::kernel_device_specific;
1206+ auto SyclQueue = createSyclObjFromImpl<queue>(Queue);
1207+ auto Ctx = SyclQueue.get_context ();
1208+ auto Dev = SyclQueue.get_device ();
1209+ size_t MaxWGSize = SIZE_MAX;
1210+ constexpr bool IsUndefinedKernelName{std::is_same_v<KernelName, auto_name>};
1211+
1212+ if (!IsUndefinedKernelName) {
1213+ auto ExecBundle =
1214+ get_kernel_bundle<KernelName, bundle_state::executable>(Ctx, {Dev});
1215+ kernel Kernel = ExecBundle.template get_kernel <KernelName>();
1216+ MaxWGSize = Kernel.template get_info <work_group_size>(Dev);
1217+ }
1218+
1219+ return MaxWGSize;
1220+ }
1221+
12031222namespace reduction {
12041223template <typename KernelName, strategy S, class ... Ts> struct MainKrn ;
12051224template <typename KernelName, strategy S, class ... Ts> struct AuxKrn ;
@@ -1302,6 +1321,8 @@ struct NDRangeReduction<
13021321 reduction::strategy::group_reduce_and_last_wg_detection,
13031322 decltype (NWorkGroupsFinished)>;
13041323
1324+ WGSize = std::min (WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1325+
13051326 CGH.parallel_for <Name>(NDRange, Properties, [=](nd_item<1 > NDId) {
13061327 // Call user's functions. Reducer.MValue gets initialized there.
13071328 typename Reduction::reducer_type Reducer;
@@ -1515,6 +1536,8 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
15151536 using Name = __sycl_reduction_kernel<reduction::MainKrn, KernelName,
15161537 reduction::strategy::range_basic>;
15171538
1539+ WGSize = std::min (WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1540+
15181541 CGH.parallel_for <Name>(NDRange, Properties, [=](nd_item<1 > NDId) {
15191542 // Call user's functions. Reducer.MValue gets initialized there.
15201543 reducer_type Reducer = reducer_type (IdentityContainer, BOp);
@@ -1628,14 +1651,14 @@ struct NDRangeReduction<
16281651 using reducer_type = typename Reduction::reducer_type;
16291652 using element_type = typename ReducerTraits<reducer_type>::element_type;
16301653
1631- std::ignore = Queue;
16321654 using Name = __sycl_reduction_kernel<
16331655 reduction::MainKrn, KernelName,
16341656 reduction::strategy::local_mem_tree_and_atomic_cross_wg>;
16351657 Redu.template withInitializedMem <Name>(CGH, [&](auto Out) {
16361658 size_t NElements = Reduction::num_elements;
16371659 size_t WGSize = NDRange.get_local_range ().size ();
16381660
1661+ WGSize = std::min (WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
16391662 // Use local memory to reduce elements in work-groups into zero-th
16401663 // element.
16411664 local_accessor<element_type, 1 > LocalReds{WGSize, CGH};
@@ -1722,6 +1745,8 @@ struct NDRangeReduction<
17221745 reduction::MainKrn, KernelName,
17231746 reduction::strategy::group_reduce_and_multiple_kernels>;
17241747
1748+ MaxWGSize = std::min (MaxWGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1749+
17251750 CGH.parallel_for <Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
17261751 // Call user's functions. Reducer.MValue gets initialized there.
17271752 typename Reduction::reducer_type Reducer;
@@ -1781,6 +1806,8 @@ struct NDRangeReduction<
17811806 reduction::AuxKrn, KernelName,
17821807 reduction::strategy::group_reduce_and_multiple_kernels>;
17831808
1809+ WGSize = std::min (WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1810+
17841811 bool IsUpdateOfUserVar = !Reduction::is_usm &&
17851812 !Redu.initializeToIdentity () &&
17861813 NWorkGroups == 1 ;
@@ -1874,6 +1901,9 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
18741901 reduction::strategy::basic,
18751902 decltype (KernelTag)>;
18761903
1904+ MaxWGSize =
1905+ std::min (MaxWGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1906+
18771907 CGH.parallel_for <Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
18781908 // Call user's functions. Reducer.MValue gets initialized there.
18791909 typename Reduction::reducer_type Reducer =
@@ -1978,6 +2008,8 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
19782008 reduction::strategy::basic,
19792009 decltype (KernelTag)>;
19802010
2011+ WGSize = std::min (WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
2012+
19812013 range<1 > GlobalRange = {UniformPow2WG ? NWorkItems
19822014 : NWorkGroups * WGSize};
19832015 nd_range<1 > Range{GlobalRange, range<1 >(WGSize)};
@@ -2295,8 +2327,9 @@ template <class KernelName, class Accessor> struct NDRangeMulti;
22952327} // namespace reduction::main_krn
22962328template <typename KernelName, typename KernelType, int Dims,
22972329 typename PropertiesT, typename ... Reductions, size_t ... Is>
2298- void reduCGFuncMulti (handler &CGH, KernelType KernelFunc,
2299- const nd_range<Dims> &Range, PropertiesT Properties,
2330+ void reduCGFuncMulti (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2331+ KernelType KernelFunc, const nd_range<Dims> &Range,
2332+ PropertiesT Properties,
23002333 std::tuple<Reductions...> &ReduTuple,
23012334 std::index_sequence<Is...> ReduIndices) {
23022335 size_t WGSize = Range.get_local_range ().size ();
@@ -2334,6 +2367,8 @@ void reduCGFuncMulti(handler &CGH, KernelType KernelFunc,
23342367 reduction::strategy::multi,
23352368 decltype (KernelTag)>;
23362369
2370+ WGSize = std::min (WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
2371+
23372372 CGH.parallel_for <Name>(Range, Properties, [=](nd_item<Dims> NDIt) {
23382373 // We can deduce IsOneWG from the tag type.
23392374 constexpr bool IsOneWG =
@@ -2495,7 +2530,8 @@ template <class KernelName, class Predicate> struct Multi;
24952530} // namespace reduction::aux_krn
24962531template <typename KernelName, typename KernelType, typename ... Reductions,
24972532 size_t ... Is>
2498- size_t reduAuxCGFunc (handler &CGH, size_t NWorkItems, size_t MaxWGSize,
2533+ size_t reduAuxCGFunc (handler &CGH, std::shared_ptr<queue_impl> &Queue,
2534+ size_t NWorkItems, size_t MaxWGSize,
24992535 std::tuple<Reductions...> &ReduTuple,
25002536 std::index_sequence<Is...> ReduIndices) {
25012537 size_t NWorkGroups;
@@ -2533,6 +2569,8 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
25332569 using Name = __sycl_reduction_kernel<reduction::AuxKrn, KernelName,
25342570 reduction::strategy::multi,
25352571 decltype (Predicate)>;
2572+ WGSize = std::min (WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
2573+
25362574 // TODO: Opportunity to parallelize across number of elements
25372575 range<1 > GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize};
25382576 nd_range<1 > Range{GlobalRange, range<1 >(WGSize)};
@@ -2617,15 +2655,15 @@ template <> struct NDRangeReduction<reduction::strategy::multi> {
26172655 " than " +
26182656 std::to_string (MaxWGSize));
26192657
2620- reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple ,
2621- ReduIndices);
2658+ reduCGFuncMulti<KernelName>(CGH, Queue, KernelFunc, NDRange, Properties,
2659+ ReduTuple, ReduIndices);
26222660 reduction::finalizeHandler (CGH);
26232661
26242662 size_t NWorkItems = NDRange.get_group_range ().size ();
26252663 while (NWorkItems > 1 ) {
26262664 reduction::withAuxHandler (CGH, [&](handler &AuxHandler) {
26272665 NWorkItems = reduAuxCGFunc<KernelName, decltype (KernelFunc)>(
2628- AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
2666+ AuxHandler, Queue, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
26292667 });
26302668 } // end while (NWorkItems > 1)
26312669 }
@@ -2741,7 +2779,29 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
27412779 // TODO: currently the preferred work group size is determined for the given
27422780 // queue/device, while it is safer to use queries to the kernel pre-compiled
27432781 // for the device.
2744- size_t PrefWGSize = reduGetPreferredWGSize (CGH.MQueue , OneElemSize);
2782+ size_t PrefWGSize = reduGetPreferredDeviceWGSize (CGH.MQueue , OneElemSize);
2783+
2784+ auto SyclQueue = createSyclObjFromImpl<queue>(CGH.MQueue );
2785+ auto Ctx = SyclQueue.get_context ();
2786+ auto Dev = SyclQueue.get_device ();
2787+
2788+ // If the reduction kernel has no user-defined name, we won't be able to query
2789+ // the exact kernel for the best wgsize, so we query all the reduction kernels
2790+ // for their wgsize and use the minimum wgsize as a safe and approximate option.
2791+ constexpr bool IsUndefinedKernelName{std::is_same_v<KernelName, auto_name>};
2792+ if (IsUndefinedKernelName) {
2793+ std::vector<kernel_id> ReductionKernelIDs = get_kernel_ids ();
2794+ for (auto KernelID : ReductionKernelIDs) {
2795+ std::string ReduKernelName = KernelID.get_name ();
2796+ if (ReduKernelName.find (" reduction" ) != std::string::npos) {
2797+ auto KB = get_kernel_bundle<bundle_state::executable>(Ctx, {KernelID});
2798+ kernel krn = KB.get_kernel (KernelID);
2799+ using namespace info ::kernel_device_specific;
2800+ size_t MaxSize = krn.template get_info <work_group_size>(Dev);
2801+ PrefWGSize = std::min (PrefWGSize, MaxSize);
2802+ }
2803+ }
2804+ }
27452805
27462806 size_t NWorkItems = Range.size ();
27472807 size_t WGSize = std::min (NWorkItems, PrefWGSize);
0 commit comments