Skip to content

Commit b4908d0

Browse files
committed
Move MKernelUsesClusterLaunch and MNDRDesc to KernelData
1 parent 06d4485 commit b4908d0

File tree

6 files changed

+93
-80
lines changed

6 files changed

+93
-80
lines changed

sycl/source/detail/handler_impl.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ class handler_impl {
107107
ur_kernel_cache_config_t MKernelCacheConfig = UR_KERNEL_CACHE_CONFIG_DEFAULT;
108108

109109
bool MKernelIsCooperative = false;
110-
bool MKernelUsesClusterLaunch = false;
111110
uint32_t MKernelWorkGroupMemorySize = 0;
112111

113112
// Extra information for bindless image copy
@@ -145,9 +144,6 @@ class handler_impl {
145144
/// have become required for this handler via require method.
146145
std::vector<detail::ArgDesc> MAssociatedAccesors;
147146

148-
/// Struct that encodes global size, local size, ...
149-
detail::NDRDescT MNDRDesc;
150-
151147
/// Type of the command group, e.g. kernel, fill. Can also encode version.
152148
/// Use getType and setType methods to access this variable unless
153149
/// manipulations with version are required

sycl/source/detail/kernel_data.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@ static void addArgsForLocalAccessor(detail::LocalAccessorImplHost *LAcc,
103103

104104
void KernelData::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
105105
const int Size, const size_t Index,
106-
size_t &IndexShift, bool IsKernelCreatedFromSource,
107-
const NDRDescT &NDRDesc) {
106+
size_t &IndexShift,
107+
bool IsKernelCreatedFromSource) {
108108
using detail::kernel_param_kind_t;
109-
size_t GlobalSize = NDRDesc.GlobalSize[0];
110-
for (size_t I = 1; I < NDRDesc.Dims; ++I) {
111-
GlobalSize *= NDRDesc.GlobalSize[I];
109+
size_t GlobalSize = MNDRDesc.GlobalSize[0];
110+
for (size_t I = 1; I < MNDRDesc.Dims; ++I) {
111+
GlobalSize *= MNDRDesc.GlobalSize[I];
112112
}
113113

114114
switch (Kind) {
@@ -145,9 +145,9 @@ void KernelData::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
145145
// So we just suppose that WG size is always default for stream.
146146
// TODO adjust MNDRDesc when device image contains kernel's attribute
147147
if (GlobalSize == 0) {
148-
GlobalSize = NDRDesc.NumWorkGroups[0];
149-
for (size_t I = 1; I < NDRDesc.Dims; ++I) {
150-
GlobalSize *= NDRDesc.NumWorkGroups[I];
148+
GlobalSize = MNDRDesc.NumWorkGroups[0];
149+
for (size_t I = 1; I < MNDRDesc.Dims; ++I) {
150+
GlobalSize *= MNDRDesc.NumWorkGroups[I];
151151
}
152152
}
153153
addArgsForGlobalAccessor(GFlushReq, Index, IndexShift, Size,
@@ -271,8 +271,7 @@ void KernelData::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
271271
}
272272
}
273273

274-
void KernelData::extractArgsAndReqs(const NDRDescT &NDRDesc,
275-
bool IsKernelCreatedFromSource) {
274+
void KernelData::extractArgsAndReqs(bool IsKernelCreatedFromSource) {
276275
std::vector<detail::ArgDesc> UnPreparedArgs = std::move(MArgs);
277276
clearArgs();
278277

@@ -290,12 +289,11 @@ void KernelData::extractArgsAndReqs(const NDRDescT &NDRDesc,
290289
const detail::kernel_param_kind_t &Kind = UnPreparedArgs[I].MType;
291290
const int &Size = UnPreparedArgs[I].MSize;
292291
const int Index = UnPreparedArgs[I].MIndex;
293-
processArg(Ptr, Kind, Size, Index, IndexShift, IsKernelCreatedFromSource,
294-
NDRDesc);
292+
processArg(Ptr, Kind, Size, Index, IndexShift, IsKernelCreatedFromSource);
295293
}
296294
}
297295

298-
void KernelData::extractArgsAndReqsFromLambda(const NDRDescT &NDRDesc) {
296+
void KernelData::extractArgsAndReqsFromLambda() {
299297
size_t IndexShift = 0;
300298
clearArgs();
301299
MArgs.reserve(MaxNumAdditionalArgs * getKernelNumArgs());
@@ -343,7 +341,7 @@ void KernelData::extractArgsAndReqsFromLambda(const NDRDescT &NDRDesc) {
343341
}
344342

345343
processArg(Ptr, Kind, Size, I, IndexShift,
346-
/*IsKernelCreatedFromSource=*/false, NDRDesc);
344+
/*IsKernelCreatedFromSource=*/false);
347345
}
348346
}
349347

sycl/source/detail/kernel_data.hpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ class KernelData {
6060

6161
void clearArgs() { MArgs.clear(); }
6262

63+
detail::NDRDescT &getNDRDesc() & { return MNDRDesc; }
64+
65+
const detail::NDRDescT &getNDRDesc() const & { return MNDRDesc; }
66+
67+
detail::NDRDescT &&getNDRDesc() && { return std::move(MNDRDesc); }
68+
69+
void setNDRDesc(const detail::NDRDescT &NDRDesc) { MNDRDesc = NDRDesc; }
70+
6371
void *getKernelFuncPtr() const { return MKernelFuncPtr; }
6472

6573
size_t getKernelNumArgs() const { return MDeviceKernelInfoPtr->NumParams; }
@@ -107,14 +115,20 @@ class KernelData {
107115

108116
bool usesAssert() const { return MDeviceKernelInfoPtr->usesAssert(); }
109117

118+
bool usesClusterLaunch() const { return MKernelUsesClusterLaunch; }
119+
120+
template <int Dims_> void setClusterDimensions(sycl::range<Dims_> N) {
121+
MKernelUsesClusterLaunch = true;
122+
MNDRDesc.setClusterDimensions(N);
123+
}
124+
110125
void processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
111126
const int Size, const size_t Index, size_t &IndexShift,
112-
bool IsKernelCreatedFromSource, const NDRDescT &NDRDesc);
127+
bool IsKernelCreatedFromSource);
113128

114-
void extractArgsAndReqs(const NDRDescT &NDRDesc,
115-
bool IsKernelCreatedFromSource);
129+
void extractArgsAndReqs(bool IsKernelCreatedFromSource);
116130

117-
void extractArgsAndReqsFromLambda(const NDRDescT &NDRDesc);
131+
void extractArgsAndReqsFromLambda();
118132

119133
private:
120134
// Storage for any SYCL Graph dynamic parameters which have been flagged for
@@ -124,6 +138,11 @@ class KernelData {
124138
/// The list of arguments for the kernel.
125139
std::vector<detail::ArgDesc> MArgs;
126140

141+
bool MKernelUsesClusterLaunch = false;
142+
143+
/// Struct that encodes global size, local size, ...
144+
detail::NDRDescT MNDRDesc;
145+
127146
// Store information about the kernel arguments.
128147
void *MKernelFuncPtr = nullptr;
129148

sycl/source/handler.cpp

Lines changed: 50 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ event handler::finalize() {
498498
// Skipping this is currently limited to simple kernels on the fast path.
499499
if (type == detail::CGType::Kernel && impl->MKernelData.getKernelFuncPtr() &&
500500
(!KernelFastPath || impl->MKernelData.hasSpecialCaptures())) {
501-
impl->MKernelData.extractArgsAndReqsFromLambda(impl->MNDRDesc);
501+
impl->MKernelData.extractArgsAndReqsFromLambda();
502502
}
503503

504504
// According to 4.7.6.9 of SYCL2020 spec, if a placeholder accessor is passed
@@ -646,8 +646,8 @@ event handler::finalize() {
646646
std::tie(CmdTraceEvent, InstanceID) = emitKernelInstrumentationData(
647647
detail::GSYCLStreamID, MKernel, MCodeLoc, impl->MIsTopCodeLoc,
648648
MKernelName.data(), *impl->MKernelData.getDeviceKernelInfoPtr(),
649-
impl->get_queue_or_null(), impl->MNDRDesc, KernelBundleImpPtr,
650-
impl->MKernelData.getArgs());
649+
impl->get_queue_or_null(), impl->MKernelData.getNDRDesc(),
650+
KernelBundleImpPtr, impl->MKernelData.getArgs());
651651
detail::emitInstrumentationGeneral(detail::GSYCLStreamID, InstanceID,
652652
CmdTraceEvent,
653653
xpti::trace_task_begin, nullptr);
@@ -659,17 +659,18 @@ event handler::finalize() {
659659
impl->get_queue(), toKernelNameStrT(MKernelName));
660660
assert(BinImage && "Failed to obtain a binary image.");
661661
}
662-
enqueueImpKernel(
663-
impl->get_queue(), impl->MNDRDesc, impl->MKernelData.getArgs(),
664-
KernelBundleImpPtr, MKernel.get(), toKernelNameStrT(MKernelName),
665-
*impl->MKernelData.getDeviceKernelInfoPtr(), RawEvents,
666-
ResultEvent.get(), nullptr, impl->MKernelCacheConfig,
667-
impl->MKernelIsCooperative, impl->MKernelUsesClusterLaunch,
668-
impl->MKernelWorkGroupMemorySize, BinImage,
669-
impl->MKernelData.getKernelFuncPtr(),
670-
impl->MKernelData.getKernelNumArgs(),
671-
impl->MKernelData.getKernelParamDescGetter(),
672-
impl->MKernelData.hasSpecialCaptures());
662+
enqueueImpKernel(impl->get_queue(), impl->MKernelData.getNDRDesc(),
663+
impl->MKernelData.getArgs(), KernelBundleImpPtr,
664+
MKernel.get(), toKernelNameStrT(MKernelName),
665+
*impl->MKernelData.getDeviceKernelInfoPtr(), RawEvents,
666+
ResultEvent.get(), nullptr, impl->MKernelCacheConfig,
667+
impl->MKernelIsCooperative,
668+
impl->MKernelData.usesClusterLaunch(),
669+
impl->MKernelWorkGroupMemorySize, BinImage,
670+
impl->MKernelData.getKernelFuncPtr(),
671+
impl->MKernelData.getKernelNumArgs(),
672+
impl->MKernelData.getKernelParamDescGetter(),
673+
impl->MKernelData.hasSpecialCaptures());
673674
#ifdef XPTI_ENABLE_INSTRUMENTATION
674675
if (xptiEnabled) {
675676
// Emit signal only when event is created
@@ -724,13 +725,14 @@ event handler::finalize() {
724725
// assert feature to check if kernel uses assertions
725726
#endif
726727
CommandGroup.reset(new detail::CGExecKernel(
727-
impl->MNDRDesc, std::move(MHostKernel), std::move(MKernel),
728-
std::move(impl->MKernelBundle), std::move(impl->CGData),
729-
std::move(impl->MKernelData).getArgs(), toKernelNameStrT(MKernelName),
728+
impl->MKernelData.getNDRDesc(), std::move(MHostKernel),
729+
std::move(MKernel), std::move(impl->MKernelBundle),
730+
std::move(impl->CGData), std::move(impl->MKernelData).getArgs(),
731+
toKernelNameStrT(MKernelName),
730732
*impl->MKernelData.getDeviceKernelInfoPtr(), std::move(MStreamStorage),
731733
std::move(impl->MAuxiliaryResources), getType(),
732734
impl->MKernelCacheConfig, impl->MKernelIsCooperative,
733-
impl->MKernelUsesClusterLaunch, impl->MKernelWorkGroupMemorySize,
735+
impl->MKernelData.usesClusterLaunch(), impl->MKernelWorkGroupMemorySize,
734736
MCodeLoc));
735737
break;
736738
}
@@ -1060,7 +1062,7 @@ void handler::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
10601062
bool IsKernelCreatedFromSource, bool IsESIMD) {
10611063
(void)IsESIMD;
10621064
impl->MKernelData.processArg(Ptr, Kind, Size, Index, IndexShift,
1063-
IsKernelCreatedFromSource, impl->MNDRDesc);
1065+
IsKernelCreatedFromSource);
10641066
}
10651067
#endif
10661068

@@ -1079,8 +1081,7 @@ void handler::setArgHelper(int ArgIndex, stream &&Str) {
10791081

10801082
void handler::extractArgsAndReqs() {
10811083
assert(MKernel && "MKernel is not initialized");
1082-
impl->MKernelData.extractArgsAndReqs(impl->MNDRDesc,
1083-
MKernel->isCreatedFromSource());
1084+
impl->MKernelData.extractArgsAndReqs(MKernel->isCreatedFromSource());
10841085
}
10851086

10861087
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
@@ -1093,7 +1094,7 @@ void handler::extractArgsAndReqsFromLambda(
10931094
(void)ParamDescGetter;
10941095
(void)NumKernelParams;
10951096
(void)IsESIMD;
1096-
impl->MKernelData.extractArgsAndReqsFromLambda(impl->MNDRDesc);
1097+
impl->MKernelData.extractArgsAndReqsFromLambda();
10971098
}
10981099

10991100
void handler::extractArgsAndReqsFromLambda(
@@ -1126,7 +1127,7 @@ void handler::extractArgsAndReqsFromLambda(
11261127
}
11271128
}
11281129
impl->MKernelData.processArg(Ptr, Kind, Size, I, IndexShift,
1129-
IsKernelCreatedFromSource, impl->MNDRDesc);
1130+
IsKernelCreatedFromSource);
11301131
}
11311132
}
11321133

@@ -2030,16 +2031,15 @@ void handler::setKernelClusterLaunch(sycl::range<3> ClusterSize, int Dims) {
20302031
throwIfGraphAssociated<
20312032
syclex::detail::UnsupportedGraphFeatures::
20322033
sycl_ext_oneapi_experimental_cuda_cluster_launch>();
2033-
impl->MKernelUsesClusterLaunch = true;
20342034

20352035
if (Dims == 1) {
20362036
sycl::range<1> ClusterSizeTrimmed = {ClusterSize[0]};
2037-
impl->MNDRDesc.setClusterDimensions(ClusterSizeTrimmed);
2037+
impl->MKernelData.setClusterDimensions(ClusterSizeTrimmed);
20382038
} else if (Dims == 2) {
20392039
sycl::range<2> ClusterSizeTrimmed = {ClusterSize[0], ClusterSize[1]};
2040-
impl->MNDRDesc.setClusterDimensions(ClusterSizeTrimmed);
2040+
impl->MKernelData.setClusterDimensions(ClusterSizeTrimmed);
20412041
} else if (Dims == 3) {
2042-
impl->MNDRDesc.setClusterDimensions(ClusterSize);
2042+
impl->MKernelData.setClusterDimensions(ClusterSize);
20432043
}
20442044
}
20452045
#endif
@@ -2048,24 +2048,21 @@ void handler::setKernelClusterLaunch(sycl::range<3> ClusterSize) {
20482048
throwIfGraphAssociated<
20492049
syclex::detail::UnsupportedGraphFeatures::
20502050
sycl_ext_oneapi_experimental_cuda_cluster_launch>();
2051-
impl->MKernelUsesClusterLaunch = true;
2052-
impl->MNDRDesc.setClusterDimensions(ClusterSize);
2051+
impl->MKernelData.setClusterDimensions(ClusterSize);
20532052
}
20542053

20552054
void handler::setKernelClusterLaunch(sycl::range<2> ClusterSize) {
20562055
throwIfGraphAssociated<
20572056
syclex::detail::UnsupportedGraphFeatures::
20582057
sycl_ext_oneapi_experimental_cuda_cluster_launch>();
2059-
impl->MKernelUsesClusterLaunch = true;
2060-
impl->MNDRDesc.setClusterDimensions(ClusterSize);
2058+
impl->MKernelData.setClusterDimensions(ClusterSize);
20612059
}
20622060

20632061
void handler::setKernelClusterLaunch(sycl::range<1> ClusterSize) {
20642062
throwIfGraphAssociated<
20652063
syclex::detail::UnsupportedGraphFeatures::
20662064
sycl_ext_oneapi_experimental_cuda_cluster_launch>();
2067-
impl->MKernelUsesClusterLaunch = true;
2068-
impl->MNDRDesc.setClusterDimensions(ClusterSize);
2065+
impl->MKernelData.setClusterDimensions(ClusterSize);
20692066
}
20702067

20712068
void handler::setKernelWorkGroupMem(size_t Size) {
@@ -2227,12 +2224,12 @@ void handler::setNDRangeDescriptorPadded(sycl::range<3> N,
22272224
bool SetNumWorkGroups, int Dims) {
22282225
if (Dims == 1) {
22292226
sycl::range<1> Range = {N[0]};
2230-
impl->MNDRDesc = NDRDescT{Range, SetNumWorkGroups};
2227+
impl->MKernelData.setNDRDesc(NDRDescT{Range, SetNumWorkGroups});
22312228
} else if (Dims == 2) {
22322229
sycl::range<2> Range = {N[0], N[1]};
2233-
impl->MNDRDesc = NDRDescT{Range, SetNumWorkGroups};
2230+
impl->MKernelData.setNDRDesc(NDRDescT{Range, SetNumWorkGroups});
22342231
} else if (Dims == 3) {
2235-
impl->MNDRDesc = NDRDescT{N, SetNumWorkGroups};
2232+
impl->MKernelData.setNDRDesc(NDRDescT{N, SetNumWorkGroups});
22362233
}
22372234
}
22382235

@@ -2241,13 +2238,13 @@ void handler::setNDRangeDescriptorPadded(sycl::range<3> NumWorkItems,
22412238
if (Dims == 1) {
22422239
sycl::range<1> NumWorkItemsTrimmed = {NumWorkItems[0]};
22432240
sycl::id<1> OffsetTrimmed = {Offset[0]};
2244-
impl->MNDRDesc = NDRDescT{NumWorkItemsTrimmed, OffsetTrimmed};
2241+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItemsTrimmed, OffsetTrimmed});
22452242
} else if (Dims == 2) {
22462243
sycl::range<2> NumWorkItemsTrimmed = {NumWorkItems[0], NumWorkItems[1]};
22472244
sycl::id<2> OffsetTrimmed = {Offset[0], Offset[1]};
2248-
impl->MNDRDesc = NDRDescT{NumWorkItemsTrimmed, OffsetTrimmed};
2245+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItemsTrimmed, OffsetTrimmed});
22492246
} else if (Dims == 3) {
2250-
impl->MNDRDesc = NDRDescT{NumWorkItems, Offset};
2247+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItems, Offset});
22512248
}
22522249
}
22532250

@@ -2258,57 +2255,57 @@ void handler::setNDRangeDescriptorPadded(sycl::range<3> NumWorkItems,
22582255
sycl::range<1> NumWorkItemsTrimmed = {NumWorkItems[0]};
22592256
sycl::range<1> LocalSizeTrimmed = {LocalSize[0]};
22602257
sycl::id<1> OffsetTrimmed = {Offset[0]};
2261-
impl->MNDRDesc =
2262-
NDRDescT{NumWorkItemsTrimmed, LocalSizeTrimmed, OffsetTrimmed};
2258+
impl->MKernelData.setNDRDesc(
2259+
NDRDescT{NumWorkItemsTrimmed, LocalSizeTrimmed, OffsetTrimmed});
22632260
} else if (Dims == 2) {
22642261
sycl::range<2> NumWorkItemsTrimmed = {NumWorkItems[0], NumWorkItems[1]};
22652262
sycl::range<2> LocalSizeTrimmed = {LocalSize[0], LocalSize[1]};
22662263
sycl::id<2> OffsetTrimmed = {Offset[0], Offset[1]};
2267-
impl->MNDRDesc =
2268-
NDRDescT{NumWorkItemsTrimmed, LocalSizeTrimmed, OffsetTrimmed};
2264+
impl->MKernelData.setNDRDesc(
2265+
NDRDescT{NumWorkItemsTrimmed, LocalSizeTrimmed, OffsetTrimmed});
22692266
} else if (Dims == 3) {
2270-
impl->MNDRDesc = NDRDescT{NumWorkItems, LocalSize, Offset};
2267+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItems, LocalSize, Offset});
22712268
}
22722269
}
22732270
#endif
22742271

22752272
void handler::setNDRangeDescriptor(sycl::range<3> N, bool SetNumWorkGroups) {
2276-
impl->MNDRDesc = NDRDescT{N, SetNumWorkGroups};
2273+
impl->MKernelData.setNDRDesc(NDRDescT{N, SetNumWorkGroups});
22772274
}
22782275
void handler::setNDRangeDescriptor(sycl::range<3> NumWorkItems,
22792276
sycl::id<3> Offset) {
2280-
impl->MNDRDesc = NDRDescT{NumWorkItems, Offset};
2277+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItems, Offset});
22812278
}
22822279
void handler::setNDRangeDescriptor(sycl::range<3> NumWorkItems,
22832280
sycl::range<3> LocalSize,
22842281
sycl::id<3> Offset) {
2285-
impl->MNDRDesc = NDRDescT{NumWorkItems, LocalSize, Offset};
2282+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItems, LocalSize, Offset});
22862283
}
22872284

22882285
void handler::setNDRangeDescriptor(sycl::range<2> N, bool SetNumWorkGroups) {
2289-
impl->MNDRDesc = NDRDescT{N, SetNumWorkGroups};
2286+
impl->MKernelData.setNDRDesc(NDRDescT{N, SetNumWorkGroups});
22902287
}
22912288
void handler::setNDRangeDescriptor(sycl::range<2> NumWorkItems,
22922289
sycl::id<2> Offset) {
2293-
impl->MNDRDesc = NDRDescT{NumWorkItems, Offset};
2290+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItems, Offset});
22942291
}
22952292
void handler::setNDRangeDescriptor(sycl::range<2> NumWorkItems,
22962293
sycl::range<2> LocalSize,
22972294
sycl::id<2> Offset) {
2298-
impl->MNDRDesc = NDRDescT{NumWorkItems, LocalSize, Offset};
2295+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItems, LocalSize, Offset});
22992296
}
23002297

23012298
void handler::setNDRangeDescriptor(sycl::range<1> N, bool SetNumWorkGroups) {
2302-
impl->MNDRDesc = NDRDescT{N, SetNumWorkGroups};
2299+
impl->MKernelData.setNDRDesc(NDRDescT{N, SetNumWorkGroups});
23032300
}
23042301
void handler::setNDRangeDescriptor(sycl::range<1> NumWorkItems,
23052302
sycl::id<1> Offset) {
2306-
impl->MNDRDesc = NDRDescT{NumWorkItems, Offset};
2303+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItems, Offset});
23072304
}
23082305
void handler::setNDRangeDescriptor(sycl::range<1> NumWorkItems,
23092306
sycl::range<1> LocalSize,
23102307
sycl::id<1> Offset) {
2311-
impl->MNDRDesc = NDRDescT{NumWorkItems, LocalSize, Offset};
2308+
impl->MKernelData.setNDRDesc(NDRDescT{NumWorkItems, LocalSize, Offset});
23122309
}
23132310

23142311
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES

0 commit comments

Comments
 (0)