Skip to content

Commit 634da82

Browse files
[SYCL] Postpone creation of HostKernel copy
Do not create a copy of HostKernel until it is used outside of the submit call stack.
1 parent eb415c4 commit 634da82

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

sycl/include/sycl/queue.hpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,17 @@ class __SYCL_EXPORT SubmissionInfo {
151151
ext::oneapi::experimental::event_mode_enum::none;
152152
};
153153

154+
template<typename KernelType, typename TransformedArgType, int Dims>
155+
std::shared_ptr<detail::HostKernelBase> CopyHostKernel(const void *KernelFunc) {
156+
const KernelType &KernelFuncRef = *static_cast<const KernelType*>(KernelFunc);
157+
std::shared_ptr<detail::HostKernelBase> HostKernel;
158+
HostKernel.reset(new detail::HostKernel<KernelType, TransformedArgType, Dims>(
159+
KernelFuncRef));
160+
return HostKernel;
161+
}
162+
163+
using HostKernelFactory = std::shared_ptr<detail::HostKernelBase>(*)(const void*);
164+
154165
using KernelParamDescGetterFuncPtr = detail::kernel_param_desc_t (*)(int);
155166

156167
// This class is intended to store the kernel runtime information,
@@ -172,13 +183,20 @@ class __SYCL_EXPORT KernelRuntimeInfo {
172183
return MKernelName;
173184
}
174185

175-
std::shared_ptr<detail::HostKernelBase> &HostKernel() { return MHostKernel; }
176-
const std::shared_ptr<detail::HostKernelBase> &HostKernel() const {
177-
return MHostKernel;
186+
char *GetKernelFuncPtr() const {
187+
return static_cast<char *>(const_cast<void *>(MHostKernelPtr));
178188
}
179189

180-
char *GetKernelFuncPtr() { return (*MHostKernel).getPtr(); }
181-
char *GetKernelFuncPtr() const { return (*MHostKernel).getPtr(); }
190+
void SaveHostKernelRef(const void *KernelFuncPtr, HostKernelFactory Factory) {
191+
MHostKernelFactory = Factory;
192+
MHostKernelPtr = KernelFuncPtr;
193+
}
194+
195+
std::shared_ptr<detail::HostKernelBase> CopyHostKernel() const {
196+
if (MHostKernelFactory && MHostKernelPtr)
197+
return MHostKernelFactory(MHostKernelPtr);
198+
return nullptr;
199+
}
182200

183201
detail::DeviceKernelInfo *&DeviceKernelInfoPtr() {
184202
return MDeviceKernelInfoPtr;
@@ -189,7 +207,10 @@ class __SYCL_EXPORT KernelRuntimeInfo {
189207

190208
private:
191209
detail::ABINeutralKernelNameStrT MKernelName;
192-
std::shared_ptr<detail::HostKernelBase> MHostKernel;
210+
HostKernelFactory MHostKernelFactory = nullptr;
211+
// Points to the kernel function object allocated on the stack. It is a
212+
// lambda function, so its type cannot be named and void* must be used here.
213+
const void *MHostKernelPtr = nullptr;
193214
detail::DeviceKernelInfo *MDeviceKernelInfoPtr = nullptr;
194215
};
195216

@@ -3720,9 +3741,8 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
37203741
typename TransformUserItemType<Dims, LambdaArgType>::type>,
37213742
void>;
37223743

3723-
KRInfo.HostKernel().reset(
3724-
new detail::HostKernel<KernelType, TransformedArgType, Dims>(
3725-
KernelFunc));
3744+
KRInfo.SaveHostKernelRef(&KernelFunc,
3745+
detail::v1::CopyHostKernel<KernelType, TransformedArgType, Dims>);
37263746

37273747
KRInfo.KernelName() = detail::getKernelName<KernelName>();
37283748
KRInfo.DeviceKernelInfoPtr() = &detail::getDeviceKernelInfo<KernelName>();

sycl/source/detail/queue_impl.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,13 +536,18 @@ detail::EventImplPtr queue_impl::submit_kernel_direct_impl(
536536
std::vector<std::shared_ptr<const void>> AuxiliaryResources;
537537
bool DiscardEvent = false;
538538

539+
// At this point, KRInfo only points to the lambda function allocated on the
540+
// stack. To keep the pointer valid after submission, we need to copy the
541+
// lambda into dynamic memory.
542+
std::shared_ptr<detail::HostKernelBase> HostKernel = KRInfo.CopyHostKernel();
543+
539544
Args = extractArgsAndReqsFromLambda(
540-
KRInfo.GetKernelFuncPtr(),
545+
HostKernel->getPtr(),
541546
KRInfo.DeviceKernelInfoPtr()->ParamDescGetter,
542547
KRInfo.DeviceKernelInfoPtr()->NumParams);
543548

544549
CommandGroup.reset(new detail::CGExecKernel(
545-
std::move(NDRDesc), KRInfo.HostKernel(),
550+
std::move(NDRDesc), std::move(HostKernel),
546551
nullptr, // MKernel
547552
nullptr, // MKernelBundle
548553
std::move(CGData), std::move(Args),

0 commit comments

Comments
 (0)