Skip to content

Commit 29f60a4

Browse files
[SYCL] Avoid unnecessary kernel copies
1 parent a2baaa8 commit 29f60a4

File tree

3 files changed

+70 additions, -22 deletions

sycl/include/sycl/detail/cg_types.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ class HostKernel : public HostKernelBase {
171171
friend class sycl::handler;
172172

173173
public:
174-
HostKernel(KernelType Kernel) : MKernel(Kernel) {}
174+
HostKernel(const KernelType &Kernel) : MKernel(Kernel) {}
175+
HostKernel(KernelType &&Kernel) : MKernel(std::move(Kernel)) {}
175176

176177
char *getPtr() override { return reinterpret_cast<char *>(&MKernel); }
177178

sycl/include/sycl/handler.hpp

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -715,17 +715,17 @@ class __SYCL_EXPORT handler {
715715
/// \param KernelFunc is a SYCL kernel function
716716
/// \param ParamDescs is the vector of kernel parameter descriptors.
717717
template <typename KernelName, typename KernelType, int Dims,
718-
typename LambdaArgType>
719-
void StoreLambda(KernelType KernelFunc) {
718+
typename LambdaArgType, typename KernelTypeUniversalRef>
719+
void StoreLambda(KernelTypeUniversalRef &&KernelFunc) {
720720
constexpr bool IsCallableWithKernelHandler =
721721
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
722722
LambdaArgType>::value;
723723

724724
// Not using `std::make_unique` to avoid unnecessary instantiations of
725725
// `std::unique_ptr<HostKernel<...>>`. Only
726726
// `std::unique_ptr<HostKernelBase>` is necessary.
727-
MHostKernel.reset(
728-
new detail::HostKernel<KernelType, LambdaArgType, Dims>(KernelFunc));
727+
MHostKernel.reset(new detail::HostKernel<KernelType, LambdaArgType, Dims>(
728+
std::forward<KernelTypeUniversalRef>(KernelFunc)));
729729

730730
constexpr bool KernelHasName =
731731
detail::getKernelName<KernelName>() != nullptr &&
@@ -739,7 +739,7 @@ class __SYCL_EXPORT handler {
739739
#ifdef __INTEL_SYCL_USE_INTEGRATION_HEADERS
740740
static_assert(
741741
!KernelHasName ||
742-
sizeof(KernelFunc) == detail::getKernelSize<KernelName>(),
742+
sizeof(KernelType) == detail::getKernelSize<KernelName>(),
743743
"Unexpected kernel lambda size. This can be caused by an "
744744
"external host compiler producing a lambda with an "
745745
"unexpected layout. This is a limitation of the compiler."
@@ -1133,7 +1133,7 @@ class __SYCL_EXPORT handler {
11331133
typename KernelName, typename KernelType, int Dims,
11341134
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
11351135
void parallel_for_lambda_impl(range<Dims> UserRange, PropertiesT Props,
1136-
KernelType KernelFunc) {
1136+
const KernelType &KernelFunc) {
11371137
#ifndef __SYCL_DEVICE_ONLY__
11381138
throwIfActionIsCreated();
11391139
throwOnKernelParameterMisuse<KernelName, KernelType>();
@@ -1545,19 +1545,21 @@ class __SYCL_EXPORT handler {
15451545
// methods side.
15461546

15471547
template <typename... TypesToForward, typename... ArgsTy>
1548-
static void kernel_single_task_unpack(handler *h, ArgsTy... Args) {
1549-
h->kernel_single_task<TypesToForward..., Props...>(Args...);
1548+
static void kernel_single_task_unpack(handler *h, ArgsTy&&... Args) {
1549+
h->kernel_single_task<TypesToForward..., Props...>(std::forward<ArgsTy>(Args)...);
15501550
}
15511551

15521552
template <typename... TypesToForward, typename... ArgsTy>
1553-
static void kernel_parallel_for_unpack(handler *h, ArgsTy... Args) {
1554-
h->kernel_parallel_for<TypesToForward..., Props...>(Args...);
1553+
static void kernel_parallel_for_unpack(handler *h, ArgsTy &&...Args) {
1554+
h->kernel_parallel_for<TypesToForward..., Props...>(
1555+
std::forward<ArgsTy>(Args)...);
15551556
}
15561557

15571558
template <typename... TypesToForward, typename... ArgsTy>
15581559
static void kernel_parallel_for_work_group_unpack(handler *h,
1559-
ArgsTy... Args) {
1560-
h->kernel_parallel_for_work_group<TypesToForward..., Props...>(Args...);
1560+
ArgsTy &&...Args) {
1561+
h->kernel_parallel_for_work_group<TypesToForward..., Props...>(
1562+
std::forward<ArgsTy>(Args)...);
15611563
}
15621564
};
15631565

@@ -1622,9 +1624,9 @@ class __SYCL_EXPORT handler {
16221624
void kernel_single_task_wrapper(const KernelType &KernelFunc) {
16231625
unpack<KernelName, KernelType, PropertiesT,
16241626
detail::KernelLambdaHasKernelHandlerArgT<KernelType>::value>(
1625-
KernelFunc, [&](auto Unpacker, auto... args) {
1627+
KernelFunc, [&](auto Unpacker, auto &&...args) {
16261628
Unpacker.template kernel_single_task_unpack<KernelName, KernelType>(
1627-
args...);
1629+
std::forward<decltype(args)>(args)...);
16281630
});
16291631
}
16301632

@@ -1635,9 +1637,10 @@ class __SYCL_EXPORT handler {
16351637
unpack<KernelName, KernelType, PropertiesT,
16361638
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
16371639
ElementType>::value>(
1638-
KernelFunc, [&](auto Unpacker, auto... args) {
1640+
KernelFunc, [&](auto Unpacker, auto &&...args) {
16391641
Unpacker.template kernel_parallel_for_unpack<KernelName, ElementType,
1640-
KernelType>(args...);
1642+
KernelType>(
1643+
std::forward<decltype(args)>(args)...);
16411644
});
16421645
}
16431646

@@ -1648,9 +1651,10 @@ class __SYCL_EXPORT handler {
16481651
unpack<KernelName, KernelType, PropertiesT,
16491652
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
16501653
ElementType>::value>(
1651-
KernelFunc, [&](auto Unpacker, auto... args) {
1654+
KernelFunc, [&](auto Unpacker, auto &&...args) {
16521655
Unpacker.template kernel_parallel_for_work_group_unpack<
1653-
KernelName, ElementType, KernelType>(args...);
1656+
KernelName, ElementType, KernelType>(
1657+
std::forward<decltype(args)>(args)...);
16541658
});
16551659
}
16561660

@@ -1900,21 +1904,21 @@ class __SYCL_EXPORT handler {
19001904
void parallel_for(range<1> NumWorkItems, const KernelType &KernelFunc) {
19011905
parallel_for_lambda_impl<KernelName>(
19021906
NumWorkItems, ext::oneapi::experimental::empty_properties_t{},
1903-
std::move(KernelFunc));
1907+
KernelFunc);
19041908
}
19051909

19061910
template <typename KernelName = detail::auto_name, typename KernelType>
19071911
void parallel_for(range<2> NumWorkItems, const KernelType &KernelFunc) {
19081912
parallel_for_lambda_impl<KernelName>(
19091913
NumWorkItems, ext::oneapi::experimental::empty_properties_t{},
1910-
std::move(KernelFunc));
1914+
KernelFunc);
19111915
}
19121916

19131917
template <typename KernelName = detail::auto_name, typename KernelType>
19141918
void parallel_for(range<3> NumWorkItems, const KernelType &KernelFunc) {
19151919
parallel_for_lambda_impl<KernelName>(
19161920
NumWorkItems, ext::oneapi::experimental::empty_properties_t{},
1917-
std::move(KernelFunc));
1921+
KernelFunc);
19181922
}
19191923

19201924
/// Enqueues a command to the SYCL runtime to invoke \p Func once.
(new test file; its path is not shown in this capture)

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <sycl/detail/core.hpp>
5+
6+
// Number of kernel<N> copy constructions observed so far; incremented by
// kernel's copy constructor and reset to 0 by main() after each check.
size_t copy_count = 0;
7+
// Number of kernel<N> move constructions observed so far; incremented by
// kernel's move constructor. main() never resets it and expects it to stay 0.
size_t move_count = 0;
8+
9+
template <int N> class kernel {
10+
public:
11+
kernel() {};
12+
kernel(const kernel &other) { copy_count++; };
13+
kernel(kernel &&other) { ++move_count; }
14+
15+
void operator()(sycl::id<1> id) const {}
16+
void operator()(sycl::nd_item<1> id) const {}
17+
void operator()() const {}
18+
};
19+
// kernel<N> has user-provided copy and move constructors, so it is not
// trivially copyable; this specialization explicitly marks it as device
// copyable so it can be submitted as a SYCL kernel function object.
template <int N> struct sycl::is_device_copyable<kernel<N>> : std::true_type {};
20+
21+
int main(int argc, char **argv) {
22+
sycl::queue q;
23+
24+
kernel<0> krn0;
25+
q.parallel_for(sycl::range<1>{1}, krn0);
26+
assert(copy_count == 1);
27+
assert(move_count == 0);
28+
copy_count = 0;
29+
30+
kernel<1> krn1;
31+
q.parallel_for(sycl::nd_range<1>{1, 1}, krn1);
32+
assert(copy_count == 1);
33+
assert(move_count == 0);
34+
copy_count = 0;
35+
36+
kernel<2> krn2;
37+
q.single_task(krn2);
38+
assert(copy_count == 1);
39+
assert(move_count == 0);
40+
copy_count = 0;
41+
42+
return 0;
43+
}

0 commit comments

Comments
 (0)