Skip to content

Commit cc71f38

Browse files
[SYCL] Postpone creation of HostKernel copy
Create possibility to delay creation of copy of HostKernel till it became used out of submit stack, i.e. by scheduler. Do type erasure for kernel lambda via vptr in HostKernelRefBase.
1 parent a671d25 commit cc71f38

File tree

5 files changed

+114
-18
lines changed

5 files changed

+114
-18
lines changed

sycl/include/sycl/detail/cg_types.hpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,93 @@ class HostKernel : public HostKernelBase {
235235
#endif
236236
};
237237

238+
// the class keeps reference to a lambda allocated externally on stack
239+
class HostKernelRefBase : public HostKernelBase {
240+
public:
241+
virtual std::shared_ptr<HostKernelBase> takeOrCopyOwnership() const = 0;
242+
};
243+
244+
template <class KernelType, class KernelArgType, int Dims>
245+
class HostKernelRef : public HostKernelRefBase {
246+
const KernelType &MKernel;
247+
248+
public:
249+
HostKernelRef(const KernelType &Kernel) : MKernel(Kernel) {}
250+
251+
virtual char *getPtr() override {
252+
return const_cast<char *>(reinterpret_cast<const char *>(&MKernel));
253+
}
254+
virtual std::shared_ptr<HostKernelBase> takeOrCopyOwnership() const override {
255+
std::shared_ptr<HostKernelBase> Kernel;
256+
Kernel.reset(new HostKernel<KernelType, KernelArgType, Dims>(MKernel));
257+
return Kernel;
258+
}
259+
260+
~HostKernelRef() noexcept override = default;
261+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
262+
// This function is needed for host-side compilation to keep kernels
263+
// instantitated. This is important for debuggers to be able to associate
264+
// kernel code instructions with source code lines.
265+
// NOTE: InstatiateKernelOnHost() should not be called.
266+
void InstantiateKernelOnHost() override {
267+
using IDBuilder = sycl::detail::Builder;
268+
constexpr bool HasKernelHandlerArg =
269+
KernelLambdaHasKernelHandlerArgT<KernelType, KernelArgType>::value;
270+
if constexpr (std::is_same_v<KernelArgType, void>) {
271+
runKernelWithoutArg(MKernel, std::bool_constant<HasKernelHandlerArg>());
272+
} else if constexpr (std::is_same_v<KernelArgType, sycl::id<Dims>>) {
273+
sycl::id ID = InitializedVal<Dims, id>::template get<0>();
274+
runKernelWithArg<const KernelArgType &>(
275+
MKernel, ID, std::bool_constant<HasKernelHandlerArg>());
276+
} else if constexpr (std::is_same_v<KernelArgType, item<Dims, true>> ||
277+
std::is_same_v<KernelArgType, item<Dims, false>>) {
278+
constexpr bool HasOffset =
279+
std::is_same_v<KernelArgType, item<Dims, true>>;
280+
if constexpr (!HasOffset) {
281+
KernelArgType Item = IDBuilder::createItem<Dims, HasOffset>(
282+
InitializedVal<Dims, range>::template get<1>(),
283+
InitializedVal<Dims, id>::template get<0>());
284+
runKernelWithArg<KernelArgType>(
285+
MKernel, Item, std::bool_constant<HasKernelHandlerArg>());
286+
} else {
287+
KernelArgType Item = IDBuilder::createItem<Dims, HasOffset>(
288+
InitializedVal<Dims, range>::template get<1>(),
289+
InitializedVal<Dims, id>::template get<0>(),
290+
InitializedVal<Dims, id>::template get<0>());
291+
runKernelWithArg<KernelArgType>(
292+
MKernel, Item, std::bool_constant<HasKernelHandlerArg>());
293+
}
294+
} else if constexpr (std::is_same_v<KernelArgType, nd_item<Dims>>) {
295+
sycl::range<Dims> Range = InitializedVal<Dims, range>::template get<1>();
296+
sycl::id<Dims> ID = InitializedVal<Dims, id>::template get<0>();
297+
sycl::group<Dims> Group =
298+
IDBuilder::createGroup<Dims>(Range, Range, Range, ID);
299+
sycl::item<Dims, true> GlobalItem =
300+
IDBuilder::createItem<Dims, true>(Range, ID, ID);
301+
sycl::item<Dims, false> LocalItem =
302+
IDBuilder::createItem<Dims, false>(Range, ID);
303+
KernelArgType NDItem =
304+
IDBuilder::createNDItem<Dims>(GlobalItem, LocalItem, Group);
305+
runKernelWithArg<const KernelArgType>(
306+
MKernel, NDItem, std::bool_constant<HasKernelHandlerArg>());
307+
} else if constexpr (std::is_same_v<KernelArgType, sycl::group<Dims>>) {
308+
sycl::range<Dims> Range = InitializedVal<Dims, range>::template get<1>();
309+
sycl::id<Dims> ID = InitializedVal<Dims, id>::template get<0>();
310+
KernelArgType Group =
311+
IDBuilder::createGroup<Dims>(Range, Range, Range, ID);
312+
runKernelWithArg<KernelArgType>(
313+
MKernel, Group, std::bool_constant<HasKernelHandlerArg>());
314+
} else {
315+
// Assume that anything else can be default-constructed. If not, this
316+
// should fail to compile and the implementor should implement a generic
317+
// case for the new argument type.
318+
runKernelWithArg<KernelArgType>(
319+
MKernel, KernelArgType{}, std::bool_constant<HasKernelHandlerArg>());
320+
}
321+
}
322+
#endif
323+
};
324+
238325
// This function is needed for host-side compilation to keep kernels
239326
// instantitated. This is important for debuggers to be able to associate
240327
// kernel code instructions with source code lines.

sycl/include/sycl/queue.hpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ auto get_native(const SyclObjectT &Obj)
6565
template <int Dims>
6666
event __SYCL_EXPORT submit_kernel_direct_with_event_impl(
6767
const queue &Queue, const nd_range<Dims> &Range,
68-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
68+
detail::HostKernelRefBase &HostKernel,
6969
detail::DeviceKernelInfo *DeviceKernelInfo,
7070
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
7171

7272
template <int Dims>
7373
void __SYCL_EXPORT submit_kernel_direct_without_event_impl(
7474
const queue &Queue, const nd_range<Dims> &Range,
75-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
75+
detail::HostKernelRefBase &HostKernel,
7676
detail::DeviceKernelInfo *DeviceKernelInfo,
7777
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
7878

@@ -180,8 +180,15 @@ auto submit_kernel_direct(
180180
"must be either sycl::nd_item or be convertible from sycl::nd_item");
181181
using TransformedArgType = sycl::nd_item<Dims>;
182182

183-
std::shared_ptr<detail::HostKernelBase> HostKernel = std::make_shared<
184-
detail::HostKernel<KernelType, TransformedArgType, Dims>>(KernelFunc);
183+
HostKernelRef<KernelType, TransformedArgType, Dims> HostKernel(KernelFunc);
184+
185+
// Instantiating the kernel on the host improves debugging.
186+
// Passing this pointer to another translation unit prevents optimization.
187+
#ifndef NDEBUG
188+
// TODO: call library to prevent dropping call due to optimization
189+
(void)detail::GetInstantiateKernelOnHostPtr<KernelType, LambdaArgType,
190+
Dims>();
191+
#endif
185192

186193
detail::DeviceKernelInfo *DeviceKernelInfoPtr =
187194
&detail::getDeviceKernelInfo<NameT>();

sycl/source/detail/queue_impl.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,16 @@ queue_impl::submit_impl(const detail::type_erased_cgfo_ty &CGF,
422422

423423
detail::EventImplPtr queue_impl::submit_kernel_direct_impl(
424424
const NDRDescT &NDRDesc,
425-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
425+
detail::HostKernelRefBase &HostKernel,
426426
detail::DeviceKernelInfo *DeviceKernelInfo, bool CallerNeedsEvent,
427427
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
428428

429429
KernelData KData;
430430

431+
std::shared_ptr<detail::HostKernelBase> HostKernelPtr = HostKernel.takeOrCopyOwnership();
432+
431433
KData.setDeviceKernelInfoPtr(DeviceKernelInfo);
432-
KData.setKernelFunc(HostKernel->getPtr());
434+
KData.setKernelFunc(HostKernelPtr->getPtr());
433435
KData.setNDRDesc(NDRDesc);
434436

435437
auto SubmitKernelFunc =
@@ -441,7 +443,7 @@ detail::EventImplPtr queue_impl::submit_kernel_direct_impl(
441443
KData.extractArgsAndReqsFromLambda();
442444

443445
CommandGroup.reset(new detail::CGExecKernel(
444-
KData.getNDRDesc(), HostKernel,
446+
KData.getNDRDesc(), std::move(HostKernelPtr),
445447
nullptr, // Kernel
446448
nullptr, // KernelBundle
447449
std::move(CGData), std::move(KData).getArgs(),

sycl/source/detail/queue_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
362362
template <int Dims>
363363
event submit_kernel_direct_with_event(
364364
const nd_range<Dims> &Range,
365-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
365+
detail::HostKernelRefBase &HostKernel,
366366
detail::DeviceKernelInfo *DeviceKernelInfo,
367367
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
368368
detail::EventImplPtr EventImpl =
@@ -374,7 +374,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
374374
template <int Dims>
375375
void submit_kernel_direct_without_event(
376376
const nd_range<Dims> &Range,
377-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
377+
detail::HostKernelRefBase &HostKernel,
378378
detail::DeviceKernelInfo *DeviceKernelInfo,
379379
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
380380
submit_kernel_direct_impl(NDRDescT{Range}, HostKernel, DeviceKernelInfo,
@@ -906,7 +906,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
906906
/// \return a SYCL event representing submitted command group or nullptr.
907907
detail::EventImplPtr submit_kernel_direct_impl(
908908
const NDRDescT &NDRDesc,
909-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
909+
detail::HostKernelRefBase &HostKernel,
910910
detail::DeviceKernelInfo *DeviceKernelInfo, bool CallerNeedsEvent,
911911
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
912912

sycl/source/queue.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ const property_list &queue::getPropList() const { return impl->getPropList(); }
466466
template <int Dims>
467467
event submit_kernel_direct_with_event_impl(
468468
const queue &Queue, const nd_range<Dims> &Range,
469-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
469+
detail::HostKernelRefBase &HostKernel,
470470
detail::DeviceKernelInfo *DeviceKernelInfo,
471471
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
472472
return getSyclObjImpl(Queue)->submit_kernel_direct_with_event(
@@ -475,26 +475,26 @@ event submit_kernel_direct_with_event_impl(
475475

476476
template event __SYCL_EXPORT submit_kernel_direct_with_event_impl<1>(
477477
const queue &Queue, const nd_range<1> &Range,
478-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
478+
detail::HostKernelRefBase &HostKernel,
479479
detail::DeviceKernelInfo *DeviceKernelInfo,
480480
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
481481

482482
template event __SYCL_EXPORT submit_kernel_direct_with_event_impl<2>(
483483
const queue &Queue, const nd_range<2> &Range,
484-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
484+
detail::HostKernelRefBase &HostKernel,
485485
detail::DeviceKernelInfo *DeviceKernelInfo,
486486
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
487487

488488
template event __SYCL_EXPORT submit_kernel_direct_with_event_impl<3>(
489489
const queue &Queue, const nd_range<3> &Range,
490-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
490+
detail::HostKernelRefBase &HostKernel,
491491
detail::DeviceKernelInfo *DeviceKernelInfo,
492492
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
493493

494494
template <int Dims>
495495
void submit_kernel_direct_without_event_impl(
496496
const queue &Queue, const nd_range<Dims> &Range,
497-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
497+
detail::HostKernelRefBase &HostKernel,
498498
detail::DeviceKernelInfo *DeviceKernelInfo,
499499
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
500500
getSyclObjImpl(Queue)->submit_kernel_direct_without_event(
@@ -503,19 +503,19 @@ void submit_kernel_direct_without_event_impl(
503503

504504
template void __SYCL_EXPORT submit_kernel_direct_without_event_impl<1>(
505505
const queue &Queue, const nd_range<1> &Range,
506-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
506+
detail::HostKernelRefBase &HostKernel,
507507
detail::DeviceKernelInfo *DeviceKernelInfo,
508508
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
509509

510510
template void __SYCL_EXPORT submit_kernel_direct_without_event_impl<2>(
511511
const queue &Queue, const nd_range<2> &Range,
512-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
512+
detail::HostKernelRefBase &HostKernel,
513513
detail::DeviceKernelInfo *DeviceKernelInfo,
514514
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
515515

516516
template void __SYCL_EXPORT submit_kernel_direct_without_event_impl<3>(
517517
const queue &Queue, const nd_range<3> &Range,
518-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
518+
detail::HostKernelRefBase &HostKernel,
519519
detail::DeviceKernelInfo *DeviceKernelInfo,
520520
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
521521

0 commit comments

Comments
 (0)