@@ -151,6 +151,17 @@ class __SYCL_EXPORT SubmissionInfo {
151151 ext::oneapi::experimental::event_mode_enum::none;
152152};
153153
154+ template <typename KernelType, typename TransformedArgType, int Dims>
155+ std::shared_ptr<detail::HostKernelBase> CopyHostKernel (const void *KernelFunc) {
156+ const KernelType &KernelFuncRef = *static_cast <const KernelType*>(KernelFunc);
157+ std::shared_ptr<detail::HostKernelBase> HostKernel;
158+ HostKernel.reset (new detail::HostKernel<KernelType, TransformedArgType, Dims>(
159+ KernelFuncRef));
160+ return HostKernel;
161+ }
162+
163+ using HostKernelFactory = std::shared_ptr<detail::HostKernelBase>(*)(const void *);
164+
154165using KernelParamDescGetterFuncPtr = detail::kernel_param_desc_t (*)(int );
155166
156167// This class is intended to store the kernel runtime information,
@@ -172,13 +183,20 @@ class __SYCL_EXPORT KernelRuntimeInfo {
172183 return MKernelName;
173184 }
174185
175- std::shared_ptr<detail::HostKernelBase> &HostKernel () { return MHostKernel; }
176- const std::shared_ptr<detail::HostKernelBase> &HostKernel () const {
177- return MHostKernel;
186+ char *GetKernelFuncPtr () const {
187+ return static_cast <char *>(const_cast <void *>(MHostKernelPtr));
178188 }
179189
180- char *GetKernelFuncPtr () { return (*MHostKernel).getPtr (); }
181- char *GetKernelFuncPtr () const { return (*MHostKernel).getPtr (); }
190+ void SaveHostKernelRef (const void *KernelFuncPtr, HostKernelFactory Factory) {
191+ MHostKernelFactory = Factory;
192+ MHostKernelPtr = KernelFuncPtr;
193+ }
194+
195+ std::shared_ptr<detail::HostKernelBase> CopyHostKernel () const {
196+ if (MHostKernelFactory && MHostKernelPtr)
197+ return MHostKernelFactory (MHostKernelPtr);
198+ return nullptr ;
199+ }
182200
183201 detail::DeviceKernelInfo *&DeviceKernelInfoPtr () {
184202 return MDeviceKernelInfoPtr;
@@ -189,7 +207,10 @@ class __SYCL_EXPORT KernelRuntimeInfo {
189207
190208private:
191209 detail::ABINeutralKernelNameStrT MKernelName;
192- std::shared_ptr<detail::HostKernelBase> MHostKernel;
210+ HostKernelFactory MHostKernelFactory = nullptr ;
211+ // points to the kernel function object allocated on stack, it's a lambda
212+ // function, so have to use void* here
213+ const void *MHostKernelPtr = nullptr ;
193214 detail::DeviceKernelInfo *MDeviceKernelInfoPtr = nullptr ;
194215};
195216
@@ -3720,9 +3741,8 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
37203741 typename TransformUserItemType<Dims, LambdaArgType>::type>,
37213742 void >;
37223743
3723- KRInfo.HostKernel ().reset (
3724- new detail::HostKernel<KernelType, TransformedArgType, Dims>(
3725- KernelFunc));
3744+ KRInfo.SaveHostKernelRef (&KernelFunc,
3745+ detail::v1::CopyHostKernel<KernelType, TransformedArgType, Dims>);
37263746
37273747 KRInfo.KernelName () = detail::getKernelName<KernelName>();
37283748 KRInfo.DeviceKernelInfoPtr () = &detail::getDeviceKernelInfo<KernelName>();
0 commit comments