Skip to content

Commit 514ac88

Browse files
[SYCL][NFCI] More refactoring around "kernel wrapping"
Follows what I started in intel#17838.
1 parent b38eb60 commit 514ac88

File tree

1 file changed

+78
-194
lines changed

1 file changed

+78
-194
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 78 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,10 +1267,6 @@ class __SYCL_EXPORT handler {
12671267
typename PropertiesT>
12681268
void parallel_for_impl(nd_range<Dims> ExecutionRange, PropertiesT Props,
12691269
const KernelType &KernelFunc) {
1270-
// TODO: Properties may change the kernel function, so in order to avoid
1271-
// conflicts they should be included in the name.
1272-
using NameT =
1273-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
12741270
using LambdaArgType =
12751271
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
12761272
static_assert(
@@ -1279,21 +1275,8 @@ class __SYCL_EXPORT handler {
12791275
"must be either sycl::nd_item or be convertible from sycl::nd_item");
12801276
using TransformedArgType = sycl::nd_item<Dims>;
12811277

1282-
(void)ExecutionRange;
1283-
(void)Props;
1284-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, TransformedArgType,
1285-
PropertiesT>::wrap(this, KernelFunc);
1286-
#ifndef __SYCL_DEVICE_ONLY__
1287-
throwIfActionIsCreated();
1288-
verifyUsedKernelBundleInternal(
1289-
detail::string_view{detail::getKernelName<NameT>()});
1290-
detail::checkValueRange<Dims>(ExecutionRange);
1291-
setNDRangeDescriptor(std::move(ExecutionRange));
1292-
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
1293-
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
1294-
std::move(KernelFunc));
1295-
setType(detail::CGType::Kernel);
1296-
#endif
1278+
wrap_kernel<WrapAs::parallel_for, KernelName, TransformedArgType, Dims>(
1279+
KernelFunc, nullptr /*Kernel*/, Props, ExecutionRange);
12971280
}
12981281

12991282
/// Defines and invokes a SYCL kernel function for the specified range.
@@ -1361,26 +1344,12 @@ class __SYCL_EXPORT handler {
13611344
void parallel_for_work_group_lambda_impl(range<Dims> NumWorkGroups,
13621345
PropertiesT Props,
13631346
const KernelType &KernelFunc) {
1364-
// TODO: Properties may change the kernel function, so in order to avoid
1365-
// conflicts they should be included in the name.
1366-
using NameT =
1367-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
13681347
using LambdaArgType =
13691348
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
1370-
(void)NumWorkGroups;
1371-
(void)Props;
1372-
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
1373-
LambdaArgType, PropertiesT>::wrap(this, KernelFunc);
1374-
#ifndef __SYCL_DEVICE_ONLY__
1375-
throwIfActionIsCreated();
1376-
verifyUsedKernelBundleInternal(
1377-
detail::string_view{detail::getKernelName<NameT>()});
1378-
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
1379-
detail::checkValueRange<Dims>(NumWorkGroups);
1380-
setNDRangeDescriptor(NumWorkGroups, /*SetNumWorkGroups=*/true);
1381-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
1382-
setType(detail::CGType::Kernel);
1383-
#endif // __SYCL_DEVICE_ONLY__
1349+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
1350+
Dims,
1351+
/*SetNumWorkGroups=*/true>(KernelFunc, nullptr /*Kernel*/,
1352+
Props, NumWorkGroups);
13841353
}
13851354

13861355
/// Hierarchical kernel invocation method of a kernel defined as a lambda
@@ -1402,29 +1371,12 @@ class __SYCL_EXPORT handler {
14021371
range<Dims> WorkGroupSize,
14031372
PropertiesT Props,
14041373
const KernelType &KernelFunc) {
1405-
// TODO: Properties may change the kernel function, so in order to avoid
1406-
// conflicts they should be included in the name.
1407-
using NameT =
1408-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
14091374
using LambdaArgType =
14101375
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
1411-
(void)NumWorkGroups;
1412-
(void)WorkGroupSize;
1413-
(void)Props;
1414-
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
1415-
LambdaArgType, PropertiesT>::wrap(this, KernelFunc);
1416-
#ifndef __SYCL_DEVICE_ONLY__
1417-
throwIfActionIsCreated();
1418-
verifyUsedKernelBundleInternal(
1419-
detail::string_view{detail::getKernelName<NameT>()});
1420-
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
14211376
nd_range<Dims> ExecRange =
14221377
nd_range<Dims>(NumWorkGroups * WorkGroupSize, WorkGroupSize);
1423-
detail::checkValueRange<Dims>(ExecRange);
1424-
setNDRangeDescriptor(std::move(ExecRange));
1425-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
1426-
setType(detail::CGType::Kernel);
1427-
#endif // __SYCL_DEVICE_ONLY__
1378+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
1379+
Dims>(KernelFunc, nullptr /*Kernel*/, Props, ExecRange);
14281380
}
14291381

14301382
#ifdef SYCL_LANGUAGE_VERSION
@@ -1636,6 +1588,59 @@ class __SYCL_EXPORT handler {
16361588
}
16371589
};
16381590

1591+
template <
1592+
WrapAs WrapAsVal, typename KernelName, typename ElementType = void,
1593+
int Dims = 1, bool SetNumWorkGroups = false,
1594+
typename PropertiesT = ext::oneapi::experimental::empty_properties_t,
1595+
typename KernelType, typename MaybeKernelTy, typename... RangeParams>
1596+
void wrap_kernel(const KernelType &KernelFunc, MaybeKernelTy &&MaybeKernel,
1597+
const PropertiesT &Props,
1598+
[[maybe_unused]] RangeParams &&...params) {
1599+
// TODO: Properties may change the kernel function, so in order to avoid
1600+
// conflicts they should be included in the name.
1601+
using NameT =
1602+
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
1603+
(void)Props;
1604+
(void)MaybeKernel;
1605+
static_assert(std::is_same_v<MaybeKernelTy, kernel> ||
1606+
std::is_same_v<MaybeKernelTy, std::nullptr_t>);
1607+
KernelWrapper<WrapAsVal, NameT, KernelType, ElementType, PropertiesT>::wrap(
1608+
this, KernelFunc);
1609+
#ifndef __SYCL_DEVICE_ONLY__
1610+
throwIfActionIsCreated();
1611+
if constexpr (std::is_same_v<MaybeKernelTy, kernel>) {
1612+
// Ignore any set kernel bundles and use the one associated with the
1613+
// kernel.
1614+
setHandlerKernelBundle(MaybeKernel);
1615+
}
1616+
verifyUsedKernelBundleInternal(
1617+
detail::string_view{detail::getKernelName<NameT>()});
1618+
setType(detail::CGType::Kernel);
1619+
1620+
detail::checkValueRange<Dims>(params...);
1621+
if constexpr (SetNumWorkGroups) {
1622+
setNDRangeDescriptor(std::move(params)...,
1623+
/*SetNumWorkGroups=*/true);
1624+
} else {
1625+
setNDRangeDescriptor(std::move(params)...);
1626+
}
1627+
1628+
if constexpr (std::is_same_v<MaybeKernelTy, std::nullptr_t>) {
1629+
StoreLambda<NameT, KernelType, Dims, ElementType>(std::move(KernelFunc));
1630+
} else {
1631+
MKernel = detail::getSyclObjImpl(std::move(MaybeKernel));
1632+
if (!lambdaAndKernelHaveEqualName<NameT>()) {
1633+
extractArgsAndReqs();
1634+
MKernelName = getKernelName();
1635+
} else {
1636+
StoreLambda<NameT, KernelType, Dims, ElementType>(
1637+
std::move(KernelFunc));
1638+
}
1639+
}
1640+
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
1641+
#endif
1642+
}
1643+
16391644
// NOTE: to support kernel_handler argument in kernel lambdas, only
16401645
// KernelWrapper<...>::wrap() must be called in this code.
16411646

@@ -1651,25 +1656,10 @@ class __SYCL_EXPORT handler {
16511656
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
16521657
void single_task_lambda_impl(PropertiesT Props,
16531658
const KernelType &KernelFunc) {
1654-
(void)Props;
1655-
// TODO: Properties may change the kernel function, so in order to avoid
1656-
// conflicts they should be included in the name.
1657-
using NameT =
1658-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
1659-
1660-
KernelWrapper<WrapAs::single_task, NameT, KernelType, void,
1661-
PropertiesT>::wrap(this, KernelFunc);
1659+
wrap_kernel<WrapAs::single_task, KernelName>(KernelFunc, nullptr /*Kernel*/,
1660+
Props, range<1>{1});
16621661
#ifndef __SYCL_DEVICE_ONLY__
1663-
throwIfActionIsCreated();
16641662
throwOnKernelParameterMisuse<KernelName, KernelType>();
1665-
verifyUsedKernelBundleInternal(
1666-
detail::string_view{detail::getKernelName<NameT>()});
1667-
// No need to check if range is out of INT_MAX limits as it's compile-time
1668-
// known constant.
1669-
setNDRangeDescriptor(range<1>{1});
1670-
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
1671-
StoreLambda<NameT, KernelType, /*Dims*/ 1, void>(KernelFunc);
1672-
setType(detail::CGType::Kernel);
16731663
#endif
16741664
}
16751665

@@ -1953,26 +1943,13 @@ class __SYCL_EXPORT handler {
19531943
__SYCL2020_DEPRECATED("offsets are deprecated in SYCL2020")
19541944
void parallel_for(range<Dims> NumWorkItems, id<Dims> WorkItemOffset,
19551945
const KernelType &KernelFunc) {
1956-
using NameT =
1957-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
19581946
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
19591947
using TransformedArgType = std::conditional_t<
19601948
std::is_integral<LambdaArgType>::value && Dims == 1, item<Dims>,
19611949
typename TransformUserItemType<Dims, LambdaArgType>::type>;
1962-
(void)NumWorkItems;
1963-
(void)WorkItemOffset;
1964-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType,
1965-
TransformedArgType>::wrap(this, KernelFunc);
1966-
#ifndef __SYCL_DEVICE_ONLY__
1967-
throwIfActionIsCreated();
1968-
verifyUsedKernelBundleInternal(
1969-
detail::string_view{detail::getKernelName<NameT>()});
1970-
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
1971-
setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset));
1972-
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
1973-
std::move(KernelFunc));
1974-
setType(detail::CGType::Kernel);
1975-
#endif
1950+
wrap_kernel<WrapAs::parallel_for, KernelName, TransformedArgType, Dims>(
1951+
KernelFunc, nullptr /*Kernel*/, {} /*Props*/, NumWorkItems,
1952+
WorkItemOffset);
19761953
}
19771954

19781955
/// Hierarchical kernel invocation method of a kernel defined as a lambda
@@ -2133,28 +2110,9 @@ class __SYCL_EXPORT handler {
21332110
const KernelType &KernelFunc) {
21342111
// Ignore any set kernel bundles and use the one associated with the kernel
21352112
setHandlerKernelBundle(Kernel);
2136-
using NameT =
2137-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
21382113
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
2139-
(void)Kernel;
2140-
(void)NumWorkItems;
2141-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2142-
this, KernelFunc);
2143-
#ifndef __SYCL_DEVICE_ONLY__
2144-
throwIfActionIsCreated();
2145-
verifyUsedKernelBundleInternal(
2146-
detail::string_view{detail::getKernelName<NameT>()});
2147-
detail::checkValueRange<Dims>(NumWorkItems);
2148-
setNDRangeDescriptor(std::move(NumWorkItems));
2149-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2150-
setType(detail::CGType::Kernel);
2151-
if (!lambdaAndKernelHaveEqualName<NameT>()) {
2152-
extractArgsAndReqs();
2153-
MKernelName = getKernelName();
2154-
} else
2155-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(
2156-
std::move(KernelFunc));
2157-
#endif
2114+
wrap_kernel<WrapAs::parallel_for, KernelName, LambdaArgType, Dims>(
2115+
KernelFunc, Kernel, {} /*Props*/, NumWorkItems);
21582116
}
21592117

21602118
/// Defines and invokes a SYCL kernel function for the specified range and
@@ -2171,31 +2129,9 @@ class __SYCL_EXPORT handler {
21712129
__SYCL2020_DEPRECATED("offsets are deprecated in SYCL 2020")
21722130
void parallel_for(kernel Kernel, range<Dims> NumWorkItems,
21732131
id<Dims> WorkItemOffset, const KernelType &KernelFunc) {
2174-
using NameT =
2175-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
21762132
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
2177-
(void)Kernel;
2178-
(void)NumWorkItems;
2179-
(void)WorkItemOffset;
2180-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2181-
this, KernelFunc);
2182-
#ifndef __SYCL_DEVICE_ONLY__
2183-
throwIfActionIsCreated();
2184-
// Ignore any set kernel bundles and use the one associated with the kernel
2185-
setHandlerKernelBundle(Kernel);
2186-
verifyUsedKernelBundleInternal(
2187-
detail::string_view{detail::getKernelName<NameT>()});
2188-
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
2189-
setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset));
2190-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2191-
setType(detail::CGType::Kernel);
2192-
if (!lambdaAndKernelHaveEqualName<NameT>()) {
2193-
extractArgsAndReqs();
2194-
MKernelName = getKernelName();
2195-
} else
2196-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(
2197-
std::move(KernelFunc));
2198-
#endif
2133+
wrap_kernel<WrapAs::parallel_for, KernelName, LambdaArgType, Dims>(
2134+
KernelFunc, Kernel, {} /*Props*/, NumWorkItems, WorkItemOffset);
21992135
}
22002136

22012137
/// Defines and invokes a SYCL kernel function for the specified range and
@@ -2211,31 +2147,10 @@ class __SYCL_EXPORT handler {
22112147
int Dims>
22122148
void parallel_for(kernel Kernel, nd_range<Dims> NDRange,
22132149
const KernelType &KernelFunc) {
2214-
using NameT =
2215-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
22162150
using LambdaArgType =
22172151
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
2218-
(void)Kernel;
2219-
(void)NDRange;
2220-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2221-
this, KernelFunc);
2222-
#ifndef __SYCL_DEVICE_ONLY__
2223-
throwIfActionIsCreated();
2224-
// Ignore any set kernel bundles and use the one associated with the kernel
2225-
setHandlerKernelBundle(Kernel);
2226-
verifyUsedKernelBundleInternal(
2227-
detail::string_view{detail::getKernelName<NameT>()});
2228-
detail::checkValueRange<Dims>(NDRange);
2229-
setNDRangeDescriptor(std::move(NDRange));
2230-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2231-
setType(detail::CGType::Kernel);
2232-
if (!lambdaAndKernelHaveEqualName<NameT>()) {
2233-
extractArgsAndReqs();
2234-
MKernelName = getKernelName();
2235-
} else
2236-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(
2237-
std::move(KernelFunc));
2238-
#endif
2152+
wrap_kernel<WrapAs::parallel_for, KernelName, LambdaArgType, Dims>(
2153+
KernelFunc, Kernel, {} /*Props*/, NDRange);
22392154
}
22402155

22412156
/// Hierarchical kernel invocation method of a kernel.
@@ -2255,26 +2170,12 @@ class __SYCL_EXPORT handler {
22552170
int Dims>
22562171
void parallel_for_work_group(kernel Kernel, range<Dims> NumWorkGroups,
22572172
const KernelType &KernelFunc) {
2258-
using NameT =
2259-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
22602173
using LambdaArgType =
22612174
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
2262-
(void)Kernel;
2263-
(void)NumWorkGroups;
2264-
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
2265-
LambdaArgType>::wrap(this, KernelFunc);
2266-
#ifndef __SYCL_DEVICE_ONLY__
2267-
throwIfActionIsCreated();
2268-
// Ignore any set kernel bundles and use the one associated with the kernel
2269-
setHandlerKernelBundle(Kernel);
2270-
verifyUsedKernelBundleInternal(
2271-
detail::string_view{detail::getKernelName<NameT>()});
2272-
detail::checkValueRange<Dims>(NumWorkGroups);
2273-
setNDRangeDescriptor(NumWorkGroups, /*SetNumWorkGroups=*/true);
2274-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2275-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
2276-
setType(detail::CGType::Kernel);
2277-
#endif // __SYCL_DEVICE_ONLY__
2175+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
2176+
Dims,
2177+
/*SetNumWorkGroups*/ true>(KernelFunc, Kernel, {} /*Props*/,
2178+
NumWorkGroups);
22782179
}
22792180

22802181
/// Hierarchical kernel invocation method of a kernel.
@@ -2297,29 +2198,12 @@ class __SYCL_EXPORT handler {
22972198
void parallel_for_work_group(kernel Kernel, range<Dims> NumWorkGroups,
22982199
range<Dims> WorkGroupSize,
22992200
const KernelType &KernelFunc) {
2300-
using NameT =
2301-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
23022201
using LambdaArgType =
23032202
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
2304-
(void)Kernel;
2305-
(void)NumWorkGroups;
2306-
(void)WorkGroupSize;
2307-
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
2308-
LambdaArgType>::wrap(this, KernelFunc);
2309-
#ifndef __SYCL_DEVICE_ONLY__
2310-
throwIfActionIsCreated();
2311-
// Ignore any set kernel bundles and use the one associated with the kernel
2312-
setHandlerKernelBundle(Kernel);
2313-
verifyUsedKernelBundleInternal(
2314-
detail::string_view{detail::getKernelName<NameT>()});
23152203
nd_range<Dims> ExecRange =
23162204
nd_range<Dims>(NumWorkGroups * WorkGroupSize, WorkGroupSize);
2317-
detail::checkValueRange<Dims>(ExecRange);
2318-
setNDRangeDescriptor(std::move(ExecRange));
2319-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2320-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
2321-
setType(detail::CGType::Kernel);
2322-
#endif // __SYCL_DEVICE_ONLY__
2205+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
2206+
Dims>(KernelFunc, Kernel, {} /*Props*/, ExecRange);
23232207
}
23242208

23252209
template <typename KernelName = detail::auto_name, typename KernelType,

0 commit comments

Comments
 (0)