Skip to content

Commit a90a0ea

Browse files
Enable move semantic for kernels.
1 parent 5e89d2f commit a90a0ea

File tree

3 files changed

+57
-16
lines changed

3 files changed

+57
-16
lines changed

sycl/include/sycl/detail/cg_types.hpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,36 @@ class HostKernel : public HostKernelBase {
239239
class HostKernelRefBase : public HostKernelBase {
240240
public:
241241
virtual std::shared_ptr<HostKernelBase> takeOrCopyOwnership() const = 0;
242+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
243+
// This function can't be called from old user code, because there is no
244+
// HostKernelRef in old user code. So, make it empty.
245+
void InstantiateKernelOnHost() override {}
246+
#endif
242247
};
243248

244-
template <class KernelType, class KernelArgType, int Dims>
249+
// primary template for movable lambda objects
250+
template <class KernelType, class KernelTypeUniversalRef, class KernelArgType, int Dims>
245251
class HostKernelRef : public HostKernelRefBase {
252+
KernelType &&MKernel;
253+
254+
public:
255+
HostKernelRef(KernelType &&Kernel) : MKernel(std::move(Kernel)) {}
256+
257+
virtual char *getPtr() override {
258+
return const_cast<char *>(reinterpret_cast<const char *>(&MKernel));
259+
}
260+
virtual std::shared_ptr<HostKernelBase> takeOrCopyOwnership() const override {
261+
std::shared_ptr<HostKernelBase> Kernel;
262+
Kernel.reset(new HostKernel<KernelType, KernelArgType, Dims>(std::move(MKernel)));
263+
return Kernel;
264+
}
265+
266+
~HostKernelRef() noexcept override = default;
267+
};
268+
269+
// specialization for copyable lambda objects
270+
template <class KernelType, class KernelTypeUniversalRef, class KernelArgType, int Dims>
271+
class HostKernelRef<KernelType, KernelTypeUniversalRef&, KernelArgType, Dims> : public HostKernelRefBase {
246272
const KernelType &MKernel;
247273

248274
public:

sycl/include/sycl/khr/free_function_commands.hpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,45 +148,56 @@ void launch_grouped(handler &h, range<3> r, range<3> size,
148148
h.parallel_for(nd_range<3>(r, size), k);
149149
}
150150

151-
template <typename KernelType>
151+
template <typename KernelType,
152+
// overload of launch_grouped for sycl::kernel must be preferred
153+
typename = typename std::enable_if<!std::is_same<typename std::decay<KernelType>::type,
154+
sycl::kernel>::value
155+
>::type>
152156
void launch_grouped(const queue &q, range<1> r, range<1> size,
153-
const KernelType &k,
157+
KernelType &&k,
154158
const sycl::detail::code_location &codeLoc =
155159
sycl::detail::code_location::current()) {
156160
#ifdef __DPCPP_ENABLE_UNFINISHED_NO_CGH_SUBMIT
157161
detail::submit_kernel_direct(q,
158162
ext::oneapi::experimental::empty_properties_t{},
159-
nd_range<1>(r, size), k);
163+
nd_range<1>(r, size), std::forward<KernelType>(k));
160164
#else
161165
submit(
162166
q, [&](handler &h) { launch_grouped<KernelType>(h, r, size, k); },
163167
codeLoc);
164168
#endif
165169
}
166-
template <typename KernelType>
170+
template <typename KernelType,
171+
typename = typename std::enable_if<!std::is_same<typename std::decay<KernelType>::type,
172+
sycl::kernel>::value
173+
>::type>
167174
void launch_grouped(const queue &q, range<2> r, range<2> size,
168-
const KernelType &k,
175+
KernelType &&k,
169176
const sycl::detail::code_location &codeLoc =
170177
sycl::detail::code_location::current()) {
171178
#ifdef __DPCPP_ENABLE_UNFINISHED_NO_CGH_SUBMIT
172179
detail::submit_kernel_direct(q,
173180
ext::oneapi::experimental::empty_properties_t{},
174-
nd_range<2>(r, size), k);
181+
nd_range<2>(r, size),
182+
std::forward<KernelType>(k));
175183
#else
176184
submit(
177185
q, [&](handler &h) { launch_grouped<KernelType>(h, r, size, k); },
178186
codeLoc);
179187
#endif
180188
}
181-
template <typename KernelType>
189+
template <typename KernelType,
190+
typename = typename std::enable_if<!std::is_same<typename std::decay<KernelType>::type,
191+
sycl::kernel>::value
192+
>::type>
182193
void launch_grouped(const queue &q, range<3> r, range<3> size,
183-
const KernelType &k,
194+
KernelType &&k,
184195
const sycl::detail::code_location &codeLoc =
185196
sycl::detail::code_location::current()) {
186197
#ifdef __DPCPP_ENABLE_UNFINISHED_NO_CGH_SUBMIT
187198
detail::submit_kernel_direct(q,
188199
ext::oneapi::experimental::empty_properties_t{},
189-
nd_range<3>(r, size), k);
200+
nd_range<3>(r, size), std::forward<KernelType>(k));
190201
#else
191202
submit(
192203
q, [&](handler &h) { launch_grouped<KernelType>(h, r, size, k); },

sycl/include/sycl/queue.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,10 @@ class __SYCL_EXPORT SubmissionInfo {
157157
} // namespace v1
158158

159159
template <typename KernelName = detail::auto_name, bool EventNeeded = false,
160-
typename PropertiesT, typename KernelType, int Dims>
160+
typename PropertiesT, typename KernelTypeUniversalRef, int Dims>
161161
auto submit_kernel_direct(
162162
const queue &Queue, PropertiesT Props, const nd_range<Dims> &Range,
163-
const KernelType &KernelFunc,
163+
KernelTypeUniversalRef &&KernelFunc,
164164
const detail::code_location &CodeLoc = detail::code_location::current()) {
165165
// TODO Properties not supported yet
166166
(void)Props;
@@ -170,6 +170,9 @@ auto submit_kernel_direct(
170170
"Setting properties not supported yet for no-CGH kernel submit.");
171171
detail::tls_code_loc_t TlsCodeLocCapture(CodeLoc);
172172

173+
using KernelType =
174+
std::remove_const_t<std::remove_reference_t<KernelTypeUniversalRef>>;
175+
173176
using NameT =
174177
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
175178
using LambdaArgType =
@@ -180,7 +183,11 @@ auto submit_kernel_direct(
180183
"must be either sycl::nd_item or be convertible from sycl::nd_item");
181184
using TransformedArgType = sycl::nd_item<Dims>;
182185

183-
HostKernelRef<KernelType, TransformedArgType, Dims> HostKernel(KernelFunc);
186+
detail::KernelWrapper<detail::WrapAs::parallel_for, NameT, KernelType,
187+
TransformedArgType, PropertiesT>::wrap(KernelFunc);
188+
189+
HostKernelRef<KernelType, KernelTypeUniversalRef, TransformedArgType, Dims>
190+
HostKernel(std::forward<KernelTypeUniversalRef>(KernelFunc));
184191

185192
// Instantiating the kernel on the host improves debugging.
186193
// Passing this pointer to another translation unit prevents optimization.
@@ -193,9 +200,6 @@ auto submit_kernel_direct(
193200
detail::DeviceKernelInfo *DeviceKernelInfoPtr =
194201
&detail::getDeviceKernelInfo<NameT>();
195202

196-
detail::KernelWrapper<detail::WrapAs::parallel_for, NameT, KernelType,
197-
TransformedArgType, PropertiesT>::wrap(KernelFunc);
198-
199203
if constexpr (EventNeeded) {
200204
return submit_kernel_direct_with_event_impl(
201205
Queue, Range, HostKernel, DeviceKernelInfoPtr,

0 commit comments

Comments
 (0)