@@ -235,6 +235,93 @@ class HostKernel : public HostKernelBase {
235235#endif
236236};
237237
238+ // the class keeps reference to a lambda allocated externally on stack
239+ class HostKernelRefBase : public HostKernelBase {
240+ public:
241+ virtual std::shared_ptr<HostKernelBase> takeOrCopyOwnership () const = 0;
242+ };
243+
244+ template <class KernelType , class KernelArgType , int Dims>
245+ class HostKernelRef : public HostKernelRefBase {
246+ const KernelType &MKernel;
247+
248+ public:
249+ HostKernelRef (const KernelType &Kernel) : MKernel(Kernel) {}
250+
251+ virtual char *getPtr () override {
252+ return const_cast <char *>(reinterpret_cast <const char *>(&MKernel));
253+ }
254+ virtual std::shared_ptr<HostKernelBase> takeOrCopyOwnership () const override {
255+ std::shared_ptr<HostKernelBase> Kernel;
256+ Kernel.reset (new HostKernel<KernelType, KernelArgType, Dims>(MKernel));
257+ return Kernel;
258+ }
259+
260+ ~HostKernelRef () noexcept override = default ;
261+ #ifndef __INTEL_PREVIEW_BREAKING_CHANGES
262+ // This function is needed for host-side compilation to keep kernels
263+ // instantitated. This is important for debuggers to be able to associate
264+ // kernel code instructions with source code lines.
265+ // NOTE: InstatiateKernelOnHost() should not be called.
266+ void InstantiateKernelOnHost () override {
267+ using IDBuilder = sycl::detail::Builder;
268+ constexpr bool HasKernelHandlerArg =
269+ KernelLambdaHasKernelHandlerArgT<KernelType, KernelArgType>::value;
270+ if constexpr (std::is_same_v<KernelArgType, void >) {
271+ runKernelWithoutArg (MKernel, std::bool_constant<HasKernelHandlerArg>());
272+ } else if constexpr (std::is_same_v<KernelArgType, sycl::id<Dims>>) {
273+ sycl::id ID = InitializedVal<Dims, id>::template get<0 >();
274+ runKernelWithArg<const KernelArgType &>(
275+ MKernel, ID, std::bool_constant<HasKernelHandlerArg>());
276+ } else if constexpr (std::is_same_v<KernelArgType, item<Dims, true >> ||
277+ std::is_same_v<KernelArgType, item<Dims, false >>) {
278+ constexpr bool HasOffset =
279+ std::is_same_v<KernelArgType, item<Dims, true >>;
280+ if constexpr (!HasOffset) {
281+ KernelArgType Item = IDBuilder::createItem<Dims, HasOffset>(
282+ InitializedVal<Dims, range>::template get<1 >(),
283+ InitializedVal<Dims, id>::template get<0 >());
284+ runKernelWithArg<KernelArgType>(
285+ MKernel, Item, std::bool_constant<HasKernelHandlerArg>());
286+ } else {
287+ KernelArgType Item = IDBuilder::createItem<Dims, HasOffset>(
288+ InitializedVal<Dims, range>::template get<1 >(),
289+ InitializedVal<Dims, id>::template get<0 >(),
290+ InitializedVal<Dims, id>::template get<0 >());
291+ runKernelWithArg<KernelArgType>(
292+ MKernel, Item, std::bool_constant<HasKernelHandlerArg>());
293+ }
294+ } else if constexpr (std::is_same_v<KernelArgType, nd_item<Dims>>) {
295+ sycl::range<Dims> Range = InitializedVal<Dims, range>::template get<1 >();
296+ sycl::id<Dims> ID = InitializedVal<Dims, id>::template get<0 >();
297+ sycl::group<Dims> Group =
298+ IDBuilder::createGroup<Dims>(Range, Range, Range, ID);
299+ sycl::item<Dims, true > GlobalItem =
300+ IDBuilder::createItem<Dims, true >(Range, ID, ID);
301+ sycl::item<Dims, false > LocalItem =
302+ IDBuilder::createItem<Dims, false >(Range, ID);
303+ KernelArgType NDItem =
304+ IDBuilder::createNDItem<Dims>(GlobalItem, LocalItem, Group);
305+ runKernelWithArg<const KernelArgType>(
306+ MKernel, NDItem, std::bool_constant<HasKernelHandlerArg>());
307+ } else if constexpr (std::is_same_v<KernelArgType, sycl::group<Dims>>) {
308+ sycl::range<Dims> Range = InitializedVal<Dims, range>::template get<1 >();
309+ sycl::id<Dims> ID = InitializedVal<Dims, id>::template get<0 >();
310+ KernelArgType Group =
311+ IDBuilder::createGroup<Dims>(Range, Range, Range, ID);
312+ runKernelWithArg<KernelArgType>(
313+ MKernel, Group, std::bool_constant<HasKernelHandlerArg>());
314+ } else {
315+ // Assume that anything else can be default-constructed. If not, this
316+ // should fail to compile and the implementor should implement a generic
317+ // case for the new argument type.
318+ runKernelWithArg<KernelArgType>(
319+ MKernel, KernelArgType{}, std::bool_constant<HasKernelHandlerArg>());
320+ }
321+ }
322+ #endif
323+ };
324+
238325// This function is needed for host-side compilation to keep kernels
239326// instantiated. This is important for debuggers to be able to associate
240327// kernel code instructions with source code lines.
0 commit comments