diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index f787b5beb73cc..0bd33aac26c2b 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -1267,10 +1267,6 @@ class __SYCL_EXPORT handler { typename PropertiesT> void parallel_for_impl(nd_range ExecutionRange, PropertiesT Props, const KernelType &KernelFunc) { - // TODO: Properties may change the kernel function, so in order to avoid - // conflicts they should be included in the name. - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; static_assert( @@ -1279,21 +1275,8 @@ class __SYCL_EXPORT handler { "must be either sycl::nd_item or be convertible from sycl::nd_item"); using TransformedArgType = sycl::nd_item; - (void)ExecutionRange; - (void)Props; - KernelWrapper::wrap(this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - detail::checkValueRange(ExecutionRange); - setNDRangeDescriptor(std::move(ExecutionRange)); - processProperties(), PropertiesT>(Props); - StoreLambda( - std::move(KernelFunc)); - setType(detail::CGType::Kernel); -#endif + wrap_kernel( + KernelFunc, nullptr /*Kernel*/, Props, ExecutionRange); } /// Defines and invokes a SYCL kernel function for the specified range. @@ -1361,26 +1344,12 @@ class __SYCL_EXPORT handler { void parallel_for_work_group_lambda_impl(range NumWorkGroups, PropertiesT Props, const KernelType &KernelFunc) { - // TODO: Properties may change the kernel function, so in order to avoid - // conflicts they should be included in the name. - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; - (void)NumWorkGroups; - (void)Props; - KernelWrapper::wrap(this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - processProperties(), PropertiesT>(Props); - detail::checkValueRange(NumWorkGroups); - setNDRangeDescriptor(NumWorkGroups, /*SetNumWorkGroups=*/true); - StoreLambda(std::move(KernelFunc)); - setType(detail::CGType::Kernel); -#endif // __SYCL_DEVICE_ONLY__ + wrap_kernel(KernelFunc, nullptr /*Kernel*/, + Props, NumWorkGroups); } /// Hierarchical kernel invocation method of a kernel defined as a lambda @@ -1402,29 +1371,12 @@ class __SYCL_EXPORT handler { range WorkGroupSize, PropertiesT Props, const KernelType &KernelFunc) { - // TODO: Properties may change the kernel function, so in order to avoid - // conflicts they should be included in the name. - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; - (void)NumWorkGroups; - (void)WorkGroupSize; - (void)Props; - KernelWrapper::wrap(this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - processProperties(), PropertiesT>(Props); nd_range ExecRange = nd_range(NumWorkGroups * WorkGroupSize, WorkGroupSize); - detail::checkValueRange(ExecRange); - setNDRangeDescriptor(std::move(ExecRange)); - StoreLambda(std::move(KernelFunc)); - setType(detail::CGType::Kernel); -#endif // __SYCL_DEVICE_ONLY__ + wrap_kernel(KernelFunc, nullptr /*Kernel*/, Props, ExecRange); } #ifdef SYCL_LANGUAGE_VERSION @@ -1636,6 +1588,59 @@ class __SYCL_EXPORT handler { } }; + template < + WrapAs WrapAsVal, typename KernelName, typename ElementType = void, + int Dims = 1, bool SetNumWorkGroups = false, + typename PropertiesT = ext::oneapi::experimental::empty_properties_t, + typename KernelType, typename MaybeKernelTy, typename... RangeParams> + void wrap_kernel(const KernelType &KernelFunc, MaybeKernelTy &&MaybeKernel, + const PropertiesT &Props, + [[maybe_unused]] RangeParams &&...params) { + // TODO: Properties may change the kernel function, so in order to avoid + // conflicts they should be included in the name. + using NameT = + typename detail::get_kernel_name_t::name; + (void)Props; + (void)MaybeKernel; + static_assert(std::is_same_v || + std::is_same_v); + KernelWrapper::wrap( + this, KernelFunc); +#ifndef __SYCL_DEVICE_ONLY__ + throwIfActionIsCreated(); + if constexpr (std::is_same_v) { + // Ignore any set kernel bundles and use the one associated with the + // kernel. + setHandlerKernelBundle(MaybeKernel); + } + verifyUsedKernelBundleInternal( + detail::string_view{detail::getKernelName()}); + setType(detail::CGType::Kernel); + + detail::checkValueRange(params...); + if constexpr (SetNumWorkGroups) { + setNDRangeDescriptor(std::move(params)..., + /*SetNumWorkGroups=*/true); + } else { + setNDRangeDescriptor(std::move(params)...); + } + + if constexpr (std::is_same_v) { + StoreLambda(std::move(KernelFunc)); + } else { + MKernel = detail::getSyclObjImpl(std::move(MaybeKernel)); + if (!lambdaAndKernelHaveEqualName()) { + extractArgsAndReqs(); + MKernelName = getKernelName(); + } else { + StoreLambda( + std::move(KernelFunc)); + } + } + processProperties(), PropertiesT>(Props); +#endif + } + // NOTE: to support kernel_handler argument in kernel lambdas, only // KernelWrapper<...>::wrap() must be called in this code. @@ -1651,25 +1656,10 @@ class __SYCL_EXPORT handler { typename PropertiesT = ext::oneapi::experimental::empty_properties_t> void single_task_lambda_impl(PropertiesT Props, const KernelType &KernelFunc) { - (void)Props; - // TODO: Properties may change the kernel function, so in order to avoid - // conflicts they should be included in the name. - using NameT = - typename detail::get_kernel_name_t::name; - - KernelWrapper::wrap(this, KernelFunc); + wrap_kernel(KernelFunc, nullptr /*Kernel*/, + Props, range<1>{1}); #ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); throwOnKernelParameterMisuse(); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - // No need to check if range is out of INT_MAX limits as it's compile-time - // known constant. - setNDRangeDescriptor(range<1>{1}); - processProperties(), PropertiesT>(Props); - StoreLambda(KernelFunc); - setType(detail::CGType::Kernel); #endif } @@ -1953,26 +1943,13 @@ class __SYCL_EXPORT handler { __SYCL2020_DEPRECATED("offsets are deprecated in SYCL2020") void parallel_for(range NumWorkItems, id WorkItemOffset, const KernelType &KernelFunc) { - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; using TransformedArgType = std::conditional_t< std::is_integral::value && Dims == 1, item, typename TransformUserItemType::type>; - (void)NumWorkItems; - (void)WorkItemOffset; - KernelWrapper::wrap(this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - detail::checkValueRange(NumWorkItems, WorkItemOffset); - setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset)); - StoreLambda( - std::move(KernelFunc)); - setType(detail::CGType::Kernel); -#endif + wrap_kernel( + KernelFunc, nullptr /*Kernel*/, {} /*Props*/, NumWorkItems, + WorkItemOffset); } /// Hierarchical kernel invocation method of a kernel defined as a lambda @@ -2133,28 +2110,9 @@ class __SYCL_EXPORT handler { const KernelType &KernelFunc) { // Ignore any set kernel bundles and use the one associated with the kernel setHandlerKernelBundle(Kernel); - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; - (void)Kernel; - (void)NumWorkItems; - KernelWrapper::wrap( - this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - detail::checkValueRange(NumWorkItems); - setNDRangeDescriptor(std::move(NumWorkItems)); - MKernel = detail::getSyclObjImpl(std::move(Kernel)); - setType(detail::CGType::Kernel); - if (!lambdaAndKernelHaveEqualName()) { - extractArgsAndReqs(); - MKernelName = getKernelName(); - } else - StoreLambda( - std::move(KernelFunc)); -#endif + wrap_kernel( + KernelFunc, Kernel, {} /*Props*/, NumWorkItems); } /// Defines and invokes a SYCL kernel function for the specified range and @@ -2171,31 +2129,9 @@ class __SYCL_EXPORT handler { __SYCL2020_DEPRECATED("offsets are deprecated in SYCL 2020") void parallel_for(kernel Kernel, range NumWorkItems, id WorkItemOffset, const KernelType &KernelFunc) { - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; - (void)Kernel; - (void)NumWorkItems; - (void)WorkItemOffset; - KernelWrapper::wrap( - this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(Kernel); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - detail::checkValueRange(NumWorkItems, WorkItemOffset); - setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset)); - MKernel = detail::getSyclObjImpl(std::move(Kernel)); - setType(detail::CGType::Kernel); - if (!lambdaAndKernelHaveEqualName()) { - extractArgsAndReqs(); - MKernelName = getKernelName(); - } else - StoreLambda( - std::move(KernelFunc)); -#endif + wrap_kernel( + KernelFunc, Kernel, {} /*Props*/, NumWorkItems, WorkItemOffset); } /// Defines and invokes a SYCL kernel function for the specified range and @@ -2211,31 +2147,10 @@ class __SYCL_EXPORT handler { int Dims> void parallel_for(kernel Kernel, nd_range NDRange, const KernelType &KernelFunc) { - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; - (void)Kernel; - (void)NDRange; - KernelWrapper::wrap( - this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(Kernel); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - detail::checkValueRange(NDRange); - setNDRangeDescriptor(std::move(NDRange)); - MKernel = detail::getSyclObjImpl(std::move(Kernel)); - setType(detail::CGType::Kernel); - if (!lambdaAndKernelHaveEqualName()) { - extractArgsAndReqs(); - MKernelName = getKernelName(); - } else - StoreLambda( - std::move(KernelFunc)); -#endif + wrap_kernel( + KernelFunc, Kernel, {} /*Props*/, NDRange); } /// Hierarchical kernel invocation method of a kernel. @@ -2255,26 +2170,12 @@ class __SYCL_EXPORT handler { int Dims> void parallel_for_work_group(kernel Kernel, range NumWorkGroups, const KernelType &KernelFunc) { - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; - (void)Kernel; - (void)NumWorkGroups; - KernelWrapper::wrap(this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(Kernel); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); - detail::checkValueRange(NumWorkGroups); - setNDRangeDescriptor(NumWorkGroups, /*SetNumWorkGroups=*/true); - MKernel = detail::getSyclObjImpl(std::move(Kernel)); - StoreLambda(std::move(KernelFunc)); - setType(detail::CGType::Kernel); -#endif // __SYCL_DEVICE_ONLY__ + wrap_kernel(KernelFunc, Kernel, {} /*Props*/, + NumWorkGroups); } /// Hierarchical kernel invocation method of a kernel. @@ -2297,29 +2198,12 @@ class __SYCL_EXPORT handler { void parallel_for_work_group(kernel Kernel, range NumWorkGroups, range WorkGroupSize, const KernelType &KernelFunc) { - using NameT = - typename detail::get_kernel_name_t::name; using LambdaArgType = sycl::detail::lambda_arg_type>; - (void)Kernel; - (void)NumWorkGroups; - (void)WorkGroupSize; - KernelWrapper::wrap(this, KernelFunc); -#ifndef __SYCL_DEVICE_ONLY__ - throwIfActionIsCreated(); - // Ignore any set kernel bundles and use the one associated with the kernel - setHandlerKernelBundle(Kernel); - verifyUsedKernelBundleInternal( - detail::string_view{detail::getKernelName()}); nd_range ExecRange = nd_range(NumWorkGroups * WorkGroupSize, WorkGroupSize); - detail::checkValueRange(ExecRange); - setNDRangeDescriptor(std::move(ExecRange)); - MKernel = detail::getSyclObjImpl(std::move(Kernel)); - StoreLambda(std::move(KernelFunc)); - setType(detail::CGType::Kernel); -#endif // __SYCL_DEVICE_ONLY__ + wrap_kernel(KernelFunc, Kernel, {} /*Props*/, ExecRange); } template