Skip to content

Commit 9898e9a

Browse files
[SYCL] Postpone creation of HostKernel copy (#20240)
Make it possible to delay creating the copy of HostKernel until it is used outside the submit call stack, i.e. by the scheduler. Type-erase the kernel lambda via the vptr in HostKernelRefBase. --------- Co-authored-by: Andrei Elovikov <[email protected]>
1 parent 161ac51 commit 9898e9a

File tree

10 files changed

+192
-62
lines changed

10 files changed

+192
-62
lines changed

sycl/include/sycl/detail/cg_types.hpp

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,66 @@ class HostKernel : public HostKernelBase {
235235
#endif
236236
};
237237

238+
// Base class for host-kernel wrappers that keep only a reference to a kernel
// lambda allocated externally, on the stack. An owning HostKernelBase object
// is produced on demand through takeOrCopyOwnership().
239+
class HostKernelRefBase : public HostKernelBase {
240+
public:
241+
HostKernelRefBase() = default;
// Copy construction and copy assignment are disabled.
242+
HostKernelRefBase(const HostKernelRefBase &) = delete;
243+
HostKernelRefBase &operator=(const HostKernelRefBase &) = delete;
244+
// Returns a uniquely-owned HostKernelBase built from the referenced kernel.
// Whether the kernel is moved or copied is decided by the derived class.
245+
virtual std::unique_ptr<HostKernelBase> takeOrCopyOwnership() const = 0;
246+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
247+
// Kernels passed via HostKernelRefBase are instantiated alongside the
248+
// constructor call by means of GetInstantiateKernelOnHostPtr(), so this
// override is intentionally empty.
249+
void InstantiateKernelOnHost() override {}
250+
#endif
251+
};
252+
253+
// Primary template for movable objects: used when the kernel is passed as an
// rvalue. Holds an rvalue reference to the externally allocated kernel.
//
// Template parameters:
//   KernelType             - type of the kernel callable.
//   KernelTypeUniversalRef - type deduced for the forwarding reference the
//                            kernel was passed through; only used to select
//                            between this template and its specialization.
//   KernelArgType, Dims    - forwarded to HostKernel when ownership is taken.
254+
template <class KernelType, class KernelTypeUniversalRef, class KernelArgType,
255+
int Dims>
256+
class HostKernelRef : public HostKernelRefBase {
// Reference to the caller's kernel object; this class does not own it.
257+
KernelType &&MKernel;
258+
259+
public:
// Binds to an rvalue kernel; construction from a const lvalue is deleted.
260+
HostKernelRef(KernelType &&Kernel) : MKernel(std::move(Kernel)) {}
261+
HostKernelRef(const KernelType &Kernel) = delete;
262+
// Returns the address of the referenced kernel object as a char pointer.
263+
virtual char *getPtr() override { return reinterpret_cast<char *>(&MKernel); }
// Allocates a HostKernel constructed from std::move(MKernel) and returns it
// as an owning pointer.
// NOTE(review): the referenced kernel is presumably left in a moved-from
// state afterwards - confirm against the HostKernel constructor.
264+
virtual std::unique_ptr<HostKernelBase> takeOrCopyOwnership() const override {
265+
std::unique_ptr<HostKernelBase> Kernel;
266+
Kernel.reset(
267+
new HostKernel<KernelType, KernelArgType, Dims>(std::move(MKernel)));
268+
return Kernel;
269+
}
270+
271+
~HostKernelRef() noexcept override = default;
272+
};
273+
274+
// Specialization for copyable objects: selected when the second template
// argument is an lvalue reference type, i.e. the kernel was passed as an
// lvalue. Holds a const reference to the externally allocated kernel.
275+
template <class KernelType, class KernelTypeUniversalRef, class KernelArgType,
276+
int Dims>
277+
class HostKernelRef<KernelType, KernelTypeUniversalRef &, KernelArgType, Dims>
278+
: public HostKernelRefBase {
// Reference to the caller's kernel object; this class does not own it.
279+
const KernelType &MKernel;
280+
281+
public:
282+
HostKernelRef(const KernelType &Kernel) : MKernel(Kernel) {}
283+
// Returns the address of the referenced kernel object as a char pointer.
// Constness is cast away to match the non-const return type.
284+
virtual char *getPtr() override {
285+
return const_cast<char *>(reinterpret_cast<const char *>(&MKernel));
286+
}
// Allocates a HostKernel constructed from the const reference MKernel (a
// copy of the referenced kernel) and returns it as an owning pointer.
287+
virtual std::unique_ptr<HostKernelBase> takeOrCopyOwnership() const override {
288+
std::unique_ptr<HostKernelBase> Kernel;
289+
Kernel.reset(new HostKernel<KernelType, KernelArgType, Dims>(MKernel));
290+
return Kernel;
291+
}
292+
293+
~HostKernelRef() noexcept override = default;
294+
};
295+
238296
// This function is needed for host-side compilation to keep kernels
239-
// instantitated. This is important for debuggers to be able to associate
297+
// instantiated. This is important for debuggers to be able to associate
240298
// kernel code instructions with source code lines.
241299
template <class KernelType, class KernelArgType, int Dims>
242300
constexpr void *GetInstantiateKernelOnHostPtr() {

sycl/include/sycl/khr/free_function_commands.hpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -149,44 +149,48 @@ void launch_grouped(handler &h, range<3> r, range<3> size,
149149
}
150150

151151
template <typename KernelType>
152-
void launch_grouped(const queue &q, range<1> r, range<1> size,
153-
const KernelType &k,
152+
constexpr bool enable_kernel_function_overload =
153+
!std::is_same_v<typename std::decay_t<KernelType>, sycl::kernel>;
154+
155+
template <typename KernelType, typename = typename std::enable_if_t<
156+
enable_kernel_function_overload<KernelType>>>
157+
void launch_grouped(const queue &q, range<1> r, range<1> size, KernelType &&k,
154158
const sycl::detail::code_location &codeLoc =
155159
sycl::detail::code_location::current()) {
156160
#ifdef __DPCPP_ENABLE_UNFINISHED_NO_CGH_SUBMIT
157-
detail::submit_kernel_direct(q,
158-
ext::oneapi::experimental::empty_properties_t{},
159-
nd_range<1>(r, size), k);
161+
detail::submit_kernel_direct(
162+
q, ext::oneapi::experimental::empty_properties_t{}, nd_range<1>(r, size),
163+
std::forward<KernelType>(k));
160164
#else
161165
submit(
162166
q, [&](handler &h) { launch_grouped<KernelType>(h, r, size, k); },
163167
codeLoc);
164168
#endif
165169
}
166-
template <typename KernelType>
167-
void launch_grouped(const queue &q, range<2> r, range<2> size,
168-
const KernelType &k,
170+
template <typename KernelType, typename = typename std::enable_if_t<
171+
enable_kernel_function_overload<KernelType>>>
172+
void launch_grouped(const queue &q, range<2> r, range<2> size, KernelType &&k,
169173
const sycl::detail::code_location &codeLoc =
170174
sycl::detail::code_location::current()) {
171175
#ifdef __DPCPP_ENABLE_UNFINISHED_NO_CGH_SUBMIT
172-
detail::submit_kernel_direct(q,
173-
ext::oneapi::experimental::empty_properties_t{},
174-
nd_range<2>(r, size), k);
176+
detail::submit_kernel_direct(
177+
q, ext::oneapi::experimental::empty_properties_t{}, nd_range<2>(r, size),
178+
std::forward<KernelType>(k));
175179
#else
176180
submit(
177181
q, [&](handler &h) { launch_grouped<KernelType>(h, r, size, k); },
178182
codeLoc);
179183
#endif
180184
}
181-
template <typename KernelType>
182-
void launch_grouped(const queue &q, range<3> r, range<3> size,
183-
const KernelType &k,
185+
template <typename KernelType, typename = typename std::enable_if_t<
186+
enable_kernel_function_overload<KernelType>>>
187+
void launch_grouped(const queue &q, range<3> r, range<3> size, KernelType &&k,
184188
const sycl::detail::code_location &codeLoc =
185189
sycl::detail::code_location::current()) {
186190
#ifdef __DPCPP_ENABLE_UNFINISHED_NO_CGH_SUBMIT
187-
detail::submit_kernel_direct(q,
188-
ext::oneapi::experimental::empty_properties_t{},
189-
nd_range<3>(r, size), k);
191+
detail::submit_kernel_direct(
192+
q, ext::oneapi::experimental::empty_properties_t{}, nd_range<3>(r, size),
193+
std::forward<KernelType>(k));
190194
#else
191195
submit(
192196
q, [&](handler &h) { launch_grouped<KernelType>(h, r, size, k); },

sycl/include/sycl/queue.hpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ auto get_native(const SyclObjectT &Obj)
6565
template <int Dims>
6666
event __SYCL_EXPORT submit_kernel_direct_with_event_impl(
6767
const queue &Queue, const nd_range<Dims> &Range,
68-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
68+
detail::HostKernelRefBase &HostKernel,
6969
detail::DeviceKernelInfo *DeviceKernelInfo,
7070
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
7171

7272
template <int Dims>
7373
void __SYCL_EXPORT submit_kernel_direct_without_event_impl(
7474
const queue &Queue, const nd_range<Dims> &Range,
75-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
75+
detail::HostKernelRefBase &HostKernel,
7676
detail::DeviceKernelInfo *DeviceKernelInfo,
7777
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
7878

@@ -157,10 +157,10 @@ class __SYCL_EXPORT SubmissionInfo {
157157
} // namespace v1
158158

159159
template <typename KernelName = detail::auto_name, bool EventNeeded = false,
160-
typename PropertiesT, typename KernelType, int Dims>
160+
typename PropertiesT, typename KernelTypeUniversalRef, int Dims>
161161
auto submit_kernel_direct(
162162
const queue &Queue, PropertiesT Props, const nd_range<Dims> &Range,
163-
const KernelType &KernelFunc,
163+
KernelTypeUniversalRef &&KernelFunc,
164164
const detail::code_location &CodeLoc = detail::code_location::current()) {
165165
// TODO Properties not supported yet
166166
(void)Props;
@@ -170,6 +170,9 @@ auto submit_kernel_direct(
170170
"Setting properties not supported yet for no-CGH kernel submit.");
171171
detail::tls_code_loc_t TlsCodeLocCapture(CodeLoc);
172172

173+
using KernelType =
174+
std::remove_const_t<std::remove_reference_t<KernelTypeUniversalRef>>;
175+
173176
using NameT =
174177
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
175178
using LambdaArgType =
@@ -180,15 +183,23 @@ auto submit_kernel_direct(
180183
"must be either sycl::nd_item or be convertible from sycl::nd_item");
181184
using TransformedArgType = sycl::nd_item<Dims>;
182185

183-
std::shared_ptr<detail::HostKernelBase> HostKernel = std::make_shared<
184-
detail::HostKernel<KernelType, TransformedArgType, Dims>>(KernelFunc);
186+
detail::KernelWrapper<detail::WrapAs::parallel_for, NameT, KernelType,
187+
TransformedArgType, PropertiesT>::wrap(KernelFunc);
188+
189+
HostKernelRef<KernelType, KernelTypeUniversalRef, TransformedArgType, Dims>
190+
HostKernel(std::forward<KernelTypeUniversalRef>(KernelFunc));
191+
192+
// Instantiating the kernel on the host improves debugging.
193+
// Passing this pointer to another translation unit prevents optimization.
194+
#ifndef NDEBUG
195+
// TODO: call library to prevent dropping call due to optimization
196+
(void)
197+
detail::GetInstantiateKernelOnHostPtr<KernelType, LambdaArgType, Dims>();
198+
#endif
185199

186200
detail::DeviceKernelInfo *DeviceKernelInfoPtr =
187201
&detail::getDeviceKernelInfo<NameT>();
188202

189-
detail::KernelWrapper<detail::WrapAs::parallel_for, NameT, KernelType,
190-
TransformedArgType, PropertiesT>::wrap(KernelFunc);
191-
192203
if constexpr (EventNeeded) {
193204
return submit_kernel_direct_with_event_impl(
194205
Queue, Range, HostKernel, DeviceKernelInfoPtr,

sycl/source/detail/queue_impl.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,15 +421,17 @@ queue_impl::submit_impl(const detail::type_erased_cgfo_ty &CGF,
421421
}
422422

423423
detail::EventImplPtr queue_impl::submit_kernel_direct_impl(
424-
const NDRDescT &NDRDesc,
425-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
424+
const NDRDescT &NDRDesc, detail::HostKernelRefBase &HostKernel,
426425
detail::DeviceKernelInfo *DeviceKernelInfo, bool CallerNeedsEvent,
427426
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
428427

429428
KernelData KData;
430429

430+
std::shared_ptr<detail::HostKernelBase> HostKernelPtr =
431+
HostKernel.takeOrCopyOwnership();
432+
431433
KData.setDeviceKernelInfoPtr(DeviceKernelInfo);
432-
KData.setKernelFunc(HostKernel->getPtr());
434+
KData.setKernelFunc(HostKernelPtr->getPtr());
433435
KData.setNDRDesc(NDRDesc);
434436

435437
auto SubmitKernelFunc =
@@ -441,7 +443,7 @@ detail::EventImplPtr queue_impl::submit_kernel_direct_impl(
441443
KData.extractArgsAndReqsFromLambda();
442444

443445
CommandGroup.reset(new detail::CGExecKernel(
444-
KData.getNDRDesc(), HostKernel,
446+
KData.getNDRDesc(), std::move(HostKernelPtr),
445447
nullptr, // Kernel
446448
nullptr, // KernelBundle
447449
std::move(CGData), std::move(KData).getArgs(),

sycl/source/detail/queue_impl.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
361361

362362
template <int Dims>
363363
event submit_kernel_direct_with_event(
364-
const nd_range<Dims> &Range,
365-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
364+
const nd_range<Dims> &Range, detail::HostKernelRefBase &HostKernel,
366365
detail::DeviceKernelInfo *DeviceKernelInfo,
367366
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
368367
detail::EventImplPtr EventImpl =
@@ -373,8 +372,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
373372

374373
template <int Dims>
375374
void submit_kernel_direct_without_event(
376-
const nd_range<Dims> &Range,
377-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
375+
const nd_range<Dims> &Range, detail::HostKernelRefBase &HostKernel,
378376
detail::DeviceKernelInfo *DeviceKernelInfo,
379377
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
380378
submit_kernel_direct_impl(NDRDescT{Range}, HostKernel, DeviceKernelInfo,
@@ -905,8 +903,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
905903
///
906904
/// \return a SYCL event representing submitted command group or nullptr.
907905
detail::EventImplPtr submit_kernel_direct_impl(
908-
const NDRDescT &NDRDesc,
909-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
906+
const NDRDescT &NDRDesc, detail::HostKernelRefBase &HostKernel,
910907
detail::DeviceKernelInfo *DeviceKernelInfo, bool CallerNeedsEvent,
911908
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
912909

sycl/source/queue.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ const property_list &queue::getPropList() const { return impl->getPropList(); }
474474
template <int Dims>
475475
event submit_kernel_direct_with_event_impl(
476476
const queue &Queue, const nd_range<Dims> &Range,
477-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
477+
detail::HostKernelRefBase &HostKernel,
478478
detail::DeviceKernelInfo *DeviceKernelInfo,
479479
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
480480
return getSyclObjImpl(Queue)->submit_kernel_direct_with_event(
@@ -483,26 +483,26 @@ event submit_kernel_direct_with_event_impl(
483483

484484
template event __SYCL_EXPORT submit_kernel_direct_with_event_impl<1>(
485485
const queue &Queue, const nd_range<1> &Range,
486-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
486+
detail::HostKernelRefBase &HostKernel,
487487
detail::DeviceKernelInfo *DeviceKernelInfo,
488488
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
489489

490490
template event __SYCL_EXPORT submit_kernel_direct_with_event_impl<2>(
491491
const queue &Queue, const nd_range<2> &Range,
492-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
492+
detail::HostKernelRefBase &HostKernel,
493493
detail::DeviceKernelInfo *DeviceKernelInfo,
494494
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
495495

496496
template event __SYCL_EXPORT submit_kernel_direct_with_event_impl<3>(
497497
const queue &Queue, const nd_range<3> &Range,
498-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
498+
detail::HostKernelRefBase &HostKernel,
499499
detail::DeviceKernelInfo *DeviceKernelInfo,
500500
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
501501

502502
template <int Dims>
503503
void submit_kernel_direct_without_event_impl(
504504
const queue &Queue, const nd_range<Dims> &Range,
505-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
505+
detail::HostKernelRefBase &HostKernel,
506506
detail::DeviceKernelInfo *DeviceKernelInfo,
507507
const detail::code_location &CodeLoc, bool IsTopCodeLoc) {
508508
getSyclObjImpl(Queue)->submit_kernel_direct_without_event(
@@ -511,19 +511,19 @@ void submit_kernel_direct_without_event_impl(
511511

512512
template void __SYCL_EXPORT submit_kernel_direct_without_event_impl<1>(
513513
const queue &Queue, const nd_range<1> &Range,
514-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
514+
detail::HostKernelRefBase &HostKernel,
515515
detail::DeviceKernelInfo *DeviceKernelInfo,
516516
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
517517

518518
template void __SYCL_EXPORT submit_kernel_direct_without_event_impl<2>(
519519
const queue &Queue, const nd_range<2> &Range,
520-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
520+
detail::HostKernelRefBase &HostKernel,
521521
detail::DeviceKernelInfo *DeviceKernelInfo,
522522
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
523523

524524
template void __SYCL_EXPORT submit_kernel_direct_without_event_impl<3>(
525525
const queue &Queue, const nd_range<3> &Range,
526-
std::shared_ptr<detail::HostKernelBase> &HostKernel,
526+
detail::HostKernelRefBase &HostKernel,
527527
detail::DeviceKernelInfo *DeviceKernelInfo,
528528
const detail::code_location &CodeLoc, bool IsTopCodeLoc);
529529

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %clangxx -fsycl -c -fno-color-diagnostics -Xclang -fdump-record-layouts %s -o %t.out | FileCheck %s
2+
// REQUIRES: linux
3+
// UNSUPPORTED: libcxx
4+
5+
// clang-format off
6+
7+
#include <sycl/detail/cg_types.hpp>
8+
9+
void foo(sycl::detail::HostKernelRefBase *) {}
10+
11+
// CHECK: 0 | class sycl::detail::HostKernelRefBase
12+
// CHECK-NEXT: 0 | class sycl::detail::HostKernelBase (primary base)
13+
// CHECK-NEXT: 0 | (HostKernelBase vtable pointer)
14+
// CHECK-NEXT: | [sizeof=8, dsize=8, align=8,
15+
// CHECK-NEXT: | nvsize=8, nvalign=8]

sycl/test/abi/sycl_symbols_linux.dump

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2985,12 +2985,12 @@ _ZN4sycl3_V121__isgreaterequal_implEdd
29852985
_ZN4sycl3_V121__isgreaterequal_implEff
29862986
_ZN4sycl3_V122accelerator_selector_vERKNS0_6deviceE
29872987
_ZN4sycl3_V128verifyUSMAllocatorPropertiesERKNS0_13property_listE
2988-
_ZN4sycl3_V136submit_kernel_direct_with_event_implILi1EEENS0_5eventERKNS0_5queueERKNS0_8nd_rangeIXT_EEERSt10shared_ptrINS0_6detail14HostKernelBaseEEPNSB_16DeviceKernelInfoERKNSB_13code_locationEb
2989-
_ZN4sycl3_V136submit_kernel_direct_with_event_implILi2EEENS0_5eventERKNS0_5queueERKNS0_8nd_rangeIXT_EEERSt10shared_ptrINS0_6detail14HostKernelBaseEEPNSB_16DeviceKernelInfoERKNSB_13code_locationEb
2990-
_ZN4sycl3_V136submit_kernel_direct_with_event_implILi3EEENS0_5eventERKNS0_5queueERKNS0_8nd_rangeIXT_EEERSt10shared_ptrINS0_6detail14HostKernelBaseEEPNSB_16DeviceKernelInfoERKNSB_13code_locationEb
2991-
_ZN4sycl3_V139submit_kernel_direct_without_event_implILi1EEEvRKNS0_5queueERKNS0_8nd_rangeIXT_EEERSt10shared_ptrINS0_6detail14HostKernelBaseEEPNSA_16DeviceKernelInfoERKNSA_13code_locationEb
2992-
_ZN4sycl3_V139submit_kernel_direct_without_event_implILi2EEEvRKNS0_5queueERKNS0_8nd_rangeIXT_EEERSt10shared_ptrINS0_6detail14HostKernelBaseEEPNSA_16DeviceKernelInfoERKNSA_13code_locationEb
2993-
_ZN4sycl3_V139submit_kernel_direct_without_event_implILi3EEEvRKNS0_5queueERKNS0_8nd_rangeIXT_EEERSt10shared_ptrINS0_6detail14HostKernelBaseEEPNSA_16DeviceKernelInfoERKNSA_13code_locationEb
2988+
_ZN4sycl3_V139submit_kernel_direct_without_event_implILi1EEEvRKNS0_5queueERKNS0_8nd_rangeIXT_EEERNS0_6detail17HostKernelRefBaseEPNS9_16DeviceKernelInfoERKNS9_13code_locationEb
2989+
_ZN4sycl3_V136submit_kernel_direct_with_event_implILi1EEENS0_5eventERKNS0_5queueERKNS0_8nd_rangeIXT_EEERNS0_6detail17HostKernelRefBaseEPNSA_16DeviceKernelInfoERKNSA_13code_locationEb
2990+
_ZN4sycl3_V136submit_kernel_direct_with_event_implILi3EEENS0_5eventERKNS0_5queueERKNS0_8nd_rangeIXT_EEERNS0_6detail17HostKernelRefBaseEPNSA_16DeviceKernelInfoERKNSA_13code_locationEb
2991+
_ZN4sycl3_V139submit_kernel_direct_without_event_implILi2EEEvRKNS0_5queueERKNS0_8nd_rangeIXT_EEERNS0_6detail17HostKernelRefBaseEPNS9_16DeviceKernelInfoERKNS9_13code_locationEb
2992+
_ZN4sycl3_V139submit_kernel_direct_without_event_implILi3EEEvRKNS0_5queueERKNS0_8nd_rangeIXT_EEERNS0_6detail17HostKernelRefBaseEPNS9_16DeviceKernelInfoERKNS9_13code_locationEb
2993+
_ZN4sycl3_V136submit_kernel_direct_with_event_implILi2EEENS0_5eventERKNS0_5queueERKNS0_8nd_rangeIXT_EEERNS0_6detail17HostKernelRefBaseEPNSA_16DeviceKernelInfoERKNSA_13code_locationEb
29942994
_ZN4sycl3_V13ext5intel12experimental9pipe_base13get_pipe_nameB5cxx11EPKv
29952995
_ZN4sycl3_V13ext5intel12experimental9pipe_base17wait_non_blockingERKNS0_5eventE
29962996
_ZN4sycl3_V13ext5intel12experimental9pipe_base18get_pipe_name_implEPKv

0 commit comments

Comments
 (0)