Skip to content

Commit 71739a8

Browse files
omarahmed1111Georgi Mirazchiyski
andcommitted
Enhance querying kernels preferred wgsize
Co-authored-by: Georgi Mirazchiyski <[email protected]>
1 parent 08a2edc commit 71739a8

File tree

4 files changed

+74
-14
lines changed

4 files changed

+74
-14
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ __SYCL_EXPORT size_t reduGetMaxWGSize(std::shared_ptr<queue_impl> Queue,
144144
size_t LocalMemBytesPerWorkItem);
145145
__SYCL_EXPORT size_t reduComputeWGSize(size_t NWorkItems, size_t MaxWGSize,
146146
size_t &NWorkGroups);
147-
__SYCL_EXPORT size_t reduGetPreferredWGSize(std::shared_ptr<queue_impl> &Queue,
148-
size_t LocalMemBytesPerWorkItem);
147+
__SYCL_EXPORT size_t reduGetPreferredDeviceWGSize(
148+
std::shared_ptr<queue_impl> &Queue, size_t LocalMemBytesPerWorkItem);
149149

150150
template <typename T, class BinaryOperation, bool IsOptional>
151151
class ReducerElement;
@@ -1200,6 +1200,25 @@ void reduSaveFinalResultToUserMem(handler &CGH, Reduction &Redu) {
12001200
});
12011201
}
12021202

1203+
template <typename KernelName>
1204+
size_t reduGetPreferredKernelWGSize(std::shared_ptr<queue_impl> &Queue) {
1205+
using namespace info::kernel_device_specific;
1206+
auto SyclQueue = createSyclObjFromImpl<queue>(Queue);
1207+
auto Ctx = SyclQueue.get_context();
1208+
auto Dev = SyclQueue.get_device();
1209+
size_t MaxWGSize = SIZE_MAX;
1210+
constexpr bool IsUndefinedKernelName{std::is_same_v<KernelName, auto_name>};
1211+
1212+
if (!IsUndefinedKernelName) {
1213+
auto ExecBundle =
1214+
get_kernel_bundle<KernelName, bundle_state::executable>(Ctx, {Dev});
1215+
kernel Kernel = ExecBundle.template get_kernel<KernelName>();
1216+
MaxWGSize = Kernel.template get_info<work_group_size>(Dev);
1217+
}
1218+
1219+
return MaxWGSize;
1220+
}
1221+
12031222
namespace reduction {
12041223
template <typename KernelName, strategy S, class... Ts> struct MainKrn;
12051224
template <typename KernelName, strategy S, class... Ts> struct AuxKrn;
@@ -1302,6 +1321,8 @@ struct NDRangeReduction<
13021321
reduction::strategy::group_reduce_and_last_wg_detection,
13031322
decltype(NWorkGroupsFinished)>;
13041323

1324+
WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1325+
13051326
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
13061327
// Call user's functions. Reducer.MValue gets initialized there.
13071328
typename Reduction::reducer_type Reducer;
@@ -1515,6 +1536,8 @@ template <> struct NDRangeReduction<reduction::strategy::range_basic> {
15151536
using Name = __sycl_reduction_kernel<reduction::MainKrn, KernelName,
15161537
reduction::strategy::range_basic>;
15171538

1539+
WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1540+
15181541
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<1> NDId) {
15191542
// Call user's functions. Reducer.MValue gets initialized there.
15201543
reducer_type Reducer = reducer_type(IdentityContainer, BOp);
@@ -1628,14 +1651,14 @@ struct NDRangeReduction<
16281651
using reducer_type = typename Reduction::reducer_type;
16291652
using element_type = typename ReducerTraits<reducer_type>::element_type;
16301653

1631-
std::ignore = Queue;
16321654
using Name = __sycl_reduction_kernel<
16331655
reduction::MainKrn, KernelName,
16341656
reduction::strategy::local_mem_tree_and_atomic_cross_wg>;
16351657
Redu.template withInitializedMem<Name>(CGH, [&](auto Out) {
16361658
size_t NElements = Reduction::num_elements;
16371659
size_t WGSize = NDRange.get_local_range().size();
16381660

1661+
WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
16391662
// Use local memory to reduce elements in work-groups into zero-th
16401663
// element.
16411664
local_accessor<element_type, 1> LocalReds{WGSize, CGH};
@@ -1722,6 +1745,8 @@ struct NDRangeReduction<
17221745
reduction::MainKrn, KernelName,
17231746
reduction::strategy::group_reduce_and_multiple_kernels>;
17241747

1748+
MaxWGSize = std::min(MaxWGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1749+
17251750
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
17261751
// Call user's functions. Reducer.MValue gets initialized there.
17271752
typename Reduction::reducer_type Reducer;
@@ -1781,6 +1806,8 @@ struct NDRangeReduction<
17811806
reduction::AuxKrn, KernelName,
17821807
reduction::strategy::group_reduce_and_multiple_kernels>;
17831808

1809+
WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1810+
17841811
bool IsUpdateOfUserVar = !Reduction::is_usm &&
17851812
!Redu.initializeToIdentity() &&
17861813
NWorkGroups == 1;
@@ -1874,6 +1901,9 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
18741901
reduction::strategy::basic,
18751902
decltype(KernelTag)>;
18761903

1904+
MaxWGSize =
1905+
std::min(MaxWGSize, reduGetPreferredKernelWGSize<Name>(Queue));
1906+
18771907
CGH.parallel_for<Name>(NDRange, Properties, [=](nd_item<Dims> NDIt) {
18781908
// Call user's functions. Reducer.MValue gets initialized there.
18791909
typename Reduction::reducer_type Reducer =
@@ -1978,6 +2008,8 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
19782008
reduction::strategy::basic,
19792009
decltype(KernelTag)>;
19802010

2011+
WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
2012+
19812013
range<1> GlobalRange = {UniformPow2WG ? NWorkItems
19822014
: NWorkGroups * WGSize};
19832015
nd_range<1> Range{GlobalRange, range<1>(WGSize)};
@@ -2295,8 +2327,9 @@ template <class KernelName, class Accessor> struct NDRangeMulti;
22952327
} // namespace reduction::main_krn
22962328
template <typename KernelName, typename KernelType, int Dims,
22972329
typename PropertiesT, typename... Reductions, size_t... Is>
2298-
void reduCGFuncMulti(handler &CGH, KernelType KernelFunc,
2299-
const nd_range<Dims> &Range, PropertiesT Properties,
2330+
void reduCGFuncMulti(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2331+
KernelType KernelFunc, const nd_range<Dims> &Range,
2332+
PropertiesT Properties,
23002333
std::tuple<Reductions...> &ReduTuple,
23012334
std::index_sequence<Is...> ReduIndices) {
23022335
size_t WGSize = Range.get_local_range().size();
@@ -2334,6 +2367,8 @@ void reduCGFuncMulti(handler &CGH, KernelType KernelFunc,
23342367
reduction::strategy::multi,
23352368
decltype(KernelTag)>;
23362369

2370+
WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
2371+
23372372
CGH.parallel_for<Name>(Range, Properties, [=](nd_item<Dims> NDIt) {
23382373
// We can deduce IsOneWG from the tag type.
23392374
constexpr bool IsOneWG =
@@ -2495,7 +2530,8 @@ template <class KernelName, class Predicate> struct Multi;
24952530
} // namespace reduction::aux_krn
24962531
template <typename KernelName, typename KernelType, typename... Reductions,
24972532
size_t... Is>
2498-
size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
2533+
size_t reduAuxCGFunc(handler &CGH, std::shared_ptr<queue_impl> &Queue,
2534+
size_t NWorkItems, size_t MaxWGSize,
24992535
std::tuple<Reductions...> &ReduTuple,
25002536
std::index_sequence<Is...> ReduIndices) {
25012537
size_t NWorkGroups;
@@ -2533,6 +2569,8 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
25332569
using Name = __sycl_reduction_kernel<reduction::AuxKrn, KernelName,
25342570
reduction::strategy::multi,
25352571
decltype(Predicate)>;
2572+
WGSize = std::min(WGSize, reduGetPreferredKernelWGSize<Name>(Queue));
2573+
25362574
// TODO: Opportunity to parallelize across number of elements
25372575
range<1> GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize};
25382576
nd_range<1> Range{GlobalRange, range<1>(WGSize)};
@@ -2617,15 +2655,15 @@ template <> struct NDRangeReduction<reduction::strategy::multi> {
26172655
" than " +
26182656
std::to_string(MaxWGSize));
26192657

2620-
reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
2621-
ReduIndices);
2658+
reduCGFuncMulti<KernelName>(CGH, Queue, KernelFunc, NDRange, Properties,
2659+
ReduTuple, ReduIndices);
26222660
reduction::finalizeHandler(CGH);
26232661

26242662
size_t NWorkItems = NDRange.get_group_range().size();
26252663
while (NWorkItems > 1) {
26262664
reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
26272665
NWorkItems = reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
2628-
AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
2666+
AuxHandler, Queue, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
26292667
});
26302668
} // end while (NWorkItems > 1)
26312669
}
@@ -2741,7 +2779,29 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
27412779
// TODO: currently the preferred work group size is determined for the given
27422780
// queue/device, while it is safer to use queries to the kernel pre-compiled
27432781
// for the device.
2744-
size_t PrefWGSize = reduGetPreferredWGSize(CGH.MQueue, OneElemSize);
2782+
size_t PrefWGSize = reduGetPreferredDeviceWGSize(CGH.MQueue, OneElemSize);
2783+
2784+
auto SyclQueue = createSyclObjFromImpl<queue>(CGH.MQueue);
2785+
auto Ctx = SyclQueue.get_context();
2786+
auto Dev = SyclQueue.get_device();
2787+
2788+
// If the reduction kernel is not name defined, we won't be able to query the
2789+
// exact kernel for the best wgsize, so we query all the reduction kernels for
2790+
// thier wgsize and use the minimum wgsize as a safe and approximate option.
2791+
constexpr bool IsUndefinedKernelName{std::is_same_v<KernelName, auto_name>};
2792+
if (IsUndefinedKernelName) {
2793+
std::vector<kernel_id> ReductionKernelIDs = get_kernel_ids();
2794+
for (auto KernelID : ReductionKernelIDs) {
2795+
std::string ReduKernelName = KernelID.get_name();
2796+
if (ReduKernelName.find("reduction") != std::string::npos) {
2797+
auto KB = get_kernel_bundle<bundle_state::executable>(Ctx, {KernelID});
2798+
kernel krn = KB.get_kernel(KernelID);
2799+
using namespace info::kernel_device_specific;
2800+
size_t MaxSize = krn.template get_info<work_group_size>(Dev);
2801+
PrefWGSize = std::min(PrefWGSize, MaxSize);
2802+
}
2803+
}
2804+
}
27452805

27462806
size_t NWorkItems = Range.size();
27472807
size_t WGSize = std::min(NWorkItems, PrefWGSize);

sycl/source/detail/reduction.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ reduGetMaxWGSize(std::shared_ptr<sycl::detail::queue_impl> Queue,
113113
return WGSize;
114114
}
115115

116-
__SYCL_EXPORT size_t reduGetPreferredWGSize(std::shared_ptr<queue_impl> &Queue,
117-
size_t LocalMemBytesPerWorkItem) {
116+
__SYCL_EXPORT size_t reduGetPreferredDeviceWGSize(
117+
std::shared_ptr<queue_impl> &Queue, size_t LocalMemBytesPerWorkItem) {
118118
// TODO: Graphs extension explicit API uses a handler with a null queue to
119119
// process CGFs, in future we should have access to the device so we can
120120
// correctly calculate this.

sycl/test/abi/sycl_symbols_linux.dump

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3294,7 +3294,7 @@ _ZN4sycl3_V16detail22get_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6devi
32943294
_ZN4sycl3_V16detail22get_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EERKS5_INS0_9kernel_idESaISB_EENS0_12bundle_stateE
32953295
_ZN4sycl3_V16detail22has_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EENS0_12bundle_stateE
32963296
_ZN4sycl3_V16detail22has_kernel_bundle_implERKNS0_7contextERKSt6vectorINS0_6deviceESaIS6_EERKS5_INS0_9kernel_idESaISB_EENS0_12bundle_stateE
3297-
_ZN4sycl3_V16detail22reduGetPreferredWGSizeERSt10shared_ptrINS1_10queue_implEEm
3297+
_ZN4sycl3_V16detail28reduGetPreferredDeviceWGSizeERSt10shared_ptrINS1_10queue_implEEm
32983298
_ZN4sycl3_V16detail22removeDuplicateDevicesERKSt6vectorINS0_6deviceESaIS3_EE
32993299
_ZN4sycl3_V16detail23constructorNotificationEPvS2_NS0_6access6targetENS3_4modeERKNS1_13code_locationE
33003300
_ZN4sycl3_V16detail24find_device_intersectionERKSt6vectorINS0_13kernel_bundleILNS0_12bundle_stateE1EEESaIS5_EE

sycl/test/abi/sycl_symbols_windows.dump

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4206,7 +4206,7 @@
42064206
?reduComputeWGSize@detail@_V1@sycl@@YA_K_K0AEA_K@Z
42074207
?reduGetMaxNumConcurrentWorkGroups@detail@_V1@sycl@@YAIV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@@Z
42084208
?reduGetMaxWGSize@detail@_V1@sycl@@YA_KV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
4209-
?reduGetPreferredWGSize@detail@_V1@sycl@@YA_KAEAV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
4209+
?reduGetPreferredDeviceWGSize@detail@_V1@sycl@@YA_KAEAV?$shared_ptr@Vqueue_impl@detail@_V1@sycl@@@std@@_K@Z
42104210
?registerDynamicParameter@handler@_V1@sycl@@AEAAXAEAVdynamic_parameter_base@detail@experimental@oneapi@ext@23@H@Z
42114211
?release_external_memory@experimental@oneapi@ext@_V1@sycl@@YAXUexternal_mem@12345@AEBVdevice@45@AEBVcontext@45@@Z
42124212
?release_external_memory@experimental@oneapi@ext@_V1@sycl@@YAXUexternal_mem@12345@AEBVqueue@45@@Z

0 commit comments

Comments
 (0)