@@ -1208,8 +1208,8 @@ class __SYCL_EXPORT handler {
12081208 using KName = std::conditional_t <std::is_same<KernelType, NameT>::value,
12091209 decltype (Wrapper), NameWT>;
12101210
1211- kernel_parallel_for_wrapper<KName, TransformedArgType , decltype (Wrapper),
1212- PropertiesT>( Wrapper);
1211+ KernelWrapper<WrapAs::parallel_for, KName , decltype (Wrapper),
1212+ TransformedArgType, PropertiesT>:: wrap ( this , Wrapper);
12131213#ifndef __SYCL_DEVICE_ONLY__
12141214 verifyUsedKernelBundleInternal (
12151215 detail::string_view{detail::getKernelName<NameT>()});
@@ -1234,8 +1234,8 @@ class __SYCL_EXPORT handler {
12341234#ifndef __SYCL_FORCE_PARALLEL_FOR_RANGE_ROUNDING__
12351235 // If parallel_for range rounding is forced then only range rounded
12361236 // kernel is generated
1237- kernel_parallel_for_wrapper<NameT, TransformedArgType , KernelType,
1238- PropertiesT>( KernelFunc);
1237+ KernelWrapper<WrapAs::parallel_for, NameT , KernelType, TransformedArgType ,
1238+ PropertiesT>:: wrap ( this , KernelFunc);
12391239#ifndef __SYCL_DEVICE_ONLY__
12401240 verifyUsedKernelBundleInternal (
12411241 detail::string_view{detail::getKernelName<NameT>()});
@@ -1283,8 +1283,8 @@ class __SYCL_EXPORT handler {
12831283
12841284 (void )ExecutionRange;
12851285 (void )Props;
1286- kernel_parallel_for_wrapper<NameT, TransformedArgType , KernelType,
1287- PropertiesT>( KernelFunc);
1286+ KernelWrapper<WrapAs::parallel_for, NameT , KernelType, TransformedArgType ,
1287+ PropertiesT>:: wrap ( this , KernelFunc);
12881288#ifndef __SYCL_DEVICE_ONLY__
12891289 throwIfActionIsCreated ();
12901290 verifyUsedKernelBundleInternal (
@@ -1371,8 +1371,8 @@ class __SYCL_EXPORT handler {
13711371 sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
13721372 (void )NumWorkGroups;
13731373 (void )Props;
1374- kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType , KernelType,
1375- PropertiesT>( KernelFunc);
1374+ KernelWrapper<WrapAs::parallel_for_work_group, NameT , KernelType,
1375+ LambdaArgType, PropertiesT>:: wrap ( this , KernelFunc);
13761376#ifndef __SYCL_DEVICE_ONLY__
13771377 throwIfActionIsCreated ();
13781378 verifyUsedKernelBundleInternal (
@@ -1413,8 +1413,8 @@ class __SYCL_EXPORT handler {
14131413 (void )NumWorkGroups;
14141414 (void )WorkGroupSize;
14151415 (void )Props;
1416- kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType , KernelType,
1417- PropertiesT>( KernelFunc);
1416+ KernelWrapper<WrapAs::parallel_for_work_group, NameT , KernelType,
1417+ LambdaArgType, PropertiesT>:: wrap ( this , KernelFunc);
14181418#ifndef __SYCL_DEVICE_ONLY__
14191419 throwIfActionIsCreated ();
14201420 verifyUsedKernelBundleInternal (
@@ -1554,127 +1554,79 @@ class __SYCL_EXPORT handler {
15541554#endif
15551555 }
15561556
1557- template <typename ... Props> struct KernelPropertiesUnpackerImpl {
1558- // Just pass extra Props... as template parameters to the underlying
1559- // Caller->* member functions. Don't have reflection so try to use
1560- // templates as much as possible to reduce the amount of boilerplate code
1561- // needed. All the type checks are expected to be done at the Caller's
1562- // methods side.
1563-
1564- template <typename ... TypesToForward, typename ... ArgsTy>
1565- static void kernel_single_task_unpack (handler *h, ArgsTy &&...Args) {
1566- h->kernel_single_task <TypesToForward..., Props...>(
1567- std::forward<ArgsTy>(Args)...);
1568- }
1569-
1570- template <typename ... TypesToForward, typename ... ArgsTy>
1571- static void kernel_parallel_for_unpack (handler *h, ArgsTy &&...Args) {
1572- h->kernel_parallel_for <TypesToForward..., Props...>(
1573- std::forward<ArgsTy>(Args)...);
1574- }
1575-
1576- template <typename ... TypesToForward, typename ... ArgsTy>
1577- static void kernel_parallel_for_work_group_unpack (handler *h,
1578- ArgsTy &&...Args) {
1579- h->kernel_parallel_for_work_group <TypesToForward..., Props...>(
1580- std::forward<ArgsTy>(Args)...);
1581- }
1582- };
1583-
1584- template <typename PropertiesT>
1585- struct KernelPropertiesUnpacker : public KernelPropertiesUnpackerImpl <> {
1586- // This should always fail outside the specialization below but must be
1587- // dependent to avoid failing even if not instantiated.
1588- static_assert (
1589- ext::oneapi::experimental::is_property_list<PropertiesT>::value,
1590- " Template type is not a property list." );
1591- };
1592-
1593- template <typename ... Props>
1594- struct KernelPropertiesUnpacker <
1595- ext::oneapi::experimental::detail::properties_t <Props...>>
1596- : public KernelPropertiesUnpackerImpl<Props...> {};
1597-
1598- // Helper function to
1599- //
1600- // * Make use of the KernelPropertiesUnpacker above
1601- // * Decide if we need an extra kernel_handler parameter
1557+ // The KernelWrapper below has two purposes.
16021558 //
1603- // The interface uses a \p Lambda callback to propagate that information back
1604- // to the caller as we need the caller to communicate:
1559+ // First, from SYCL 2020, Table 129 (Member functions of the `handler` class)
1560+ // > The callable ... can optionally take a `kernel_handler` ... in which
1561+ // > case the SYCL runtime will construct an instance of `kernel_handler`
1562+ // > and pass it to the callable.
16051563 //
1606- // * Name of the method to call
1607- // * Provide explicit template type parameters for the call
1564+ // Note: the "..." elisions above hide slight wording differences between
1565+ // single_task and parallel_for (e.g. "only parameter" vs. "last parameter").
1566+ // This helper class calls the `kernel_*` entry points (which have hardcoded
1567+ // names known to the FE and special device-specific entry point attributes)
1568+ // with the proper arguments (with or without a `kernel_handler` argument,
1569+ // depending on the signature of the SYCL kernel function).
16081570 //
1609- // Couldn't think of a better way to achieve both.
1610- template <typename KernelName, typename KernelType, typename PropertiesT,
1611- bool HasKernelHandlerArg, typename FuncTy>
1612- void unpack (const KernelType &KernelFunc, FuncTy Lambda) {
1613- #ifdef __SYCL_DEVICE_ONLY__
1614- detail::CheckDeviceCopyable<KernelType>();
1615- #endif // __SYCL_DEVICE_ONLY__
1616- using MergedPropertiesT =
1617- typename detail::GetMergedKernelProperties<KernelType,
1618- PropertiesT>::type;
1619- using Unpacker = KernelPropertiesUnpacker<MergedPropertiesT>;
1620- #ifndef __SYCL_DEVICE_ONLY__
1621- // If there are properties provided by get method then process them.
1622- if constexpr (ext::oneapi::experimental::detail::
1623- HasKernelPropertiesGetMethod<const KernelType &>::value) {
1624- processProperties<detail::isKernelESIMD<KernelName>()>(
1625- KernelFunc.get (ext::oneapi::experimental::properties_tag{}));
1626- }
1627- #endif
1628- if constexpr (HasKernelHandlerArg) {
1629- kernel_handler KH;
1630- Lambda (Unpacker{}, this , KernelFunc, KH);
1631- } else {
1632- Lambda (Unpacker{}, this , KernelFunc);
1633- }
1634- }
1571+ // Second, it performs a few checks and some properties processing (including
1572+ // properties provided via the `sycl_ext_oneapi_kernel_properties` extension
1573+ // by embedding them into the kernel's type).
16351574
1636- // NOTE: to support kernel_handler argument in kernel lambdas, only
1637- // kernel_***_wrapper functions must be called in this code
1575+ enum class WrapAs { single_task, parallel_for, parallel_for_work_group };
16381576
16391577 template <
1640- typename KernelName, typename KernelType,
1641- typename PropertiesT = ext::oneapi::experimental::empty_properties_t >
1642- void kernel_single_task_wrapper (const KernelType &KernelFunc) {
1643- unpack<KernelName, KernelType, PropertiesT,
1644- detail::KernelLambdaHasKernelHandlerArgT<KernelType>::value>(
1645- KernelFunc, [&](auto Unpacker, auto &&...args ) {
1646- Unpacker.template kernel_single_task_unpack <KernelName, KernelType>(
1578+ WrapAs WrapAsVal, typename KernelName, typename KernelType,
1579+ typename ElementType,
1580+ typename PropertiesT = ext::oneapi::experimental::empty_properties_t ,
1581+ typename MergedPropertiesT = typename detail::GetMergedKernelProperties<
1582+ KernelType, PropertiesT>::type>
1583+ struct KernelWrapper ;
1584+ template <WrapAs WrapAsVal, typename KernelName, typename KernelType,
1585+ typename ElementType, typename PropertiesT, typename ... MergedProps>
1586+ struct KernelWrapper <
1587+ WrapAsVal, KernelName, KernelType, ElementType, PropertiesT,
1588+ ext::oneapi::experimental::detail::properties_t <MergedProps...>> {
1589+ static void wrap (handler *h, const KernelType &KernelFunc) {
1590+ #ifdef __SYCL_DEVICE_ONLY__
1591+ detail::CheckDeviceCopyable<KernelType>();
1592+ #else
1593+ // If there are properties provided by get method then process them.
1594+ if constexpr (ext::oneapi::experimental::detail::
1595+ HasKernelPropertiesGetMethod<
1596+ const KernelType &>::value) {
1597+ h->processProperties <detail::isKernelESIMD<KernelName>()>(
1598+ KernelFunc.get (ext::oneapi::experimental::properties_tag{}));
1599+ }
1600+ #endif
1601+ auto L = [&](auto &&...args ) {
1602+ if constexpr (WrapAsVal == WrapAs::single_task) {
1603+ h->kernel_single_task <KernelName, KernelType, MergedProps...>(
16471604 std::forward<decltype (args)>(args)...);
1648- });
1649- }
1650-
1651- template <
1652- typename KernelName, typename ElementType, typename KernelType,
1653- typename PropertiesT = ext::oneapi::experimental::empty_properties_t >
1654- void kernel_parallel_for_wrapper (const KernelType &KernelFunc) {
1655- unpack<KernelName, KernelType, PropertiesT,
1656- detail::KernelLambdaHasKernelHandlerArgT<KernelType,
1657- ElementType>::value>(
1658- KernelFunc, [&](auto Unpacker, auto &&...args ) {
1659- Unpacker.template kernel_parallel_for_unpack <KernelName, ElementType,
1660- KernelType>(
1605+ } else if constexpr (WrapAsVal == WrapAs::parallel_for) {
1606+ h->kernel_parallel_for <KernelName, ElementType, KernelType,
1607+ MergedProps...>(
16611608 std::forward<decltype (args)>(args)...);
1662- });
1663- }
1664-
1665- template <
1666- typename KernelName, typename ElementType, typename KernelType,
1667- typename PropertiesT = ext::oneapi::experimental::empty_properties_t >
1668- void kernel_parallel_for_work_group_wrapper (const KernelType &KernelFunc) {
1669- unpack<KernelName, KernelType, PropertiesT,
1670- detail::KernelLambdaHasKernelHandlerArgT<KernelType,
1671- ElementType>::value>(
1672- KernelFunc, [&](auto Unpacker, auto &&...args ) {
1673- Unpacker.template kernel_parallel_for_work_group_unpack <
1674- KernelName, ElementType, KernelType>(
1609+ } else if constexpr (WrapAsVal == WrapAs::parallel_for_work_group) {
1610+ h->kernel_parallel_for_work_group <KernelName, ElementType, KernelType,
1611+ MergedProps...>(
16751612 std::forward<decltype (args)>(args)...);
1676- });
1677- }
1613+ } else {
1614+ // Always false, but template-dependent.
1615+ static_assert (WrapAsVal != WrapAsVal, " Unexpected WrapAsVal" );
1616+ }
1617+ };
1618+ if constexpr (detail::KernelLambdaHasKernelHandlerArgT<
1619+ KernelType, ElementType>::value) {
1620+ kernel_handler KH;
1621+ L (KernelFunc, KH);
1622+ } else {
1623+ L (KernelFunc);
1624+ }
1625+ }
1626+ };
1627+
1628+ // NOTE: to support kernel_handler argument in kernel lambdas, only
1629+ // KernelWrapper<...>::wrap() must be called in this code.
16781630
16791631 /// Defines and invokes a SYCL kernel function as a function object type.
16801632 ///
@@ -1694,7 +1646,8 @@ class __SYCL_EXPORT handler {
16941646 using NameT =
16951647 typename detail::get_kernel_name_t <KernelName, KernelType>::name;
16961648
1697- kernel_single_task_wrapper<NameT, KernelType, PropertiesT>(KernelFunc);
1649+ KernelWrapper<WrapAs::single_task, NameT, KernelType, void ,
1650+ PropertiesT>::wrap (this , KernelFunc);
16981651#ifndef __SYCL_DEVICE_ONLY__
16991652 throwIfActionIsCreated ();
17001653 throwOnKernelParameterMisuse<KernelName, KernelType>();
@@ -1997,7 +1950,8 @@ class __SYCL_EXPORT handler {
19971950 typename TransformUserItemType<Dims, LambdaArgType>::type>;
19981951 (void )NumWorkItems;
19991952 (void )WorkItemOffset;
2000- kernel_parallel_for_wrapper<NameT, TransformedArgType>(KernelFunc);
1953+ KernelWrapper<WrapAs::parallel_for, NameT, KernelType,
1954+ TransformedArgType>::wrap (this , KernelFunc);
20011955#ifndef __SYCL_DEVICE_ONLY__
20021956 throwIfActionIsCreated ();
20031957 verifyUsedKernelBundleInternal (
@@ -2173,7 +2127,8 @@ class __SYCL_EXPORT handler {
21732127 using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
21742128 (void )Kernel;
21752129 (void )NumWorkItems;
2176- kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
2130+ KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap (
2131+ this , KernelFunc);
21772132#ifndef __SYCL_DEVICE_ONLY__
21782133 throwIfActionIsCreated ();
21792134 verifyUsedKernelBundleInternal (
@@ -2211,7 +2166,8 @@ class __SYCL_EXPORT handler {
22112166 (void )Kernel;
22122167 (void )NumWorkItems;
22132168 (void )WorkItemOffset;
2214- kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
2169+ KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap (
2170+ this , KernelFunc);
22152171#ifndef __SYCL_DEVICE_ONLY__
22162172 throwIfActionIsCreated ();
22172173 // Ignore any set kernel bundles and use the one associated with the kernel
@@ -2250,7 +2206,8 @@ class __SYCL_EXPORT handler {
22502206 sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
22512207 (void )Kernel;
22522208 (void )NDRange;
2253- kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
2209+ KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap (
2210+ this , KernelFunc);
22542211#ifndef __SYCL_DEVICE_ONLY__
22552212 throwIfActionIsCreated ();
22562213 // Ignore any set kernel bundles and use the one associated with the kernel
@@ -2293,7 +2250,8 @@ class __SYCL_EXPORT handler {
22932250 sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
22942251 (void )Kernel;
22952252 (void )NumWorkGroups;
2296- kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType>(KernelFunc);
2253+ KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
2254+ LambdaArgType>::wrap (this , KernelFunc);
22972255#ifndef __SYCL_DEVICE_ONLY__
22982256 throwIfActionIsCreated ();
22992257 // Ignore any set kernel bundles and use the one associated with the kernel
@@ -2335,7 +2293,8 @@ class __SYCL_EXPORT handler {
23352293 (void )Kernel;
23362294 (void )NumWorkGroups;
23372295 (void )WorkGroupSize;
2338- kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType>(KernelFunc);
2296+ KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
2297+ LambdaArgType>::wrap (this , KernelFunc);
23392298#ifndef __SYCL_DEVICE_ONLY__
23402299 throwIfActionIsCreated ();
23412300 // Ignore any set kernel bundles and use the one associated with the kernel
0 commit comments