diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 1ecaecd1d0c98..8a62a602a1fb2 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -1337,57 +1337,6 @@ class __SYCL_EXPORT handler { #endif } - /// Hierarchical kernel invocation method of a kernel defined as a lambda - /// encoding the body of each work-group to launch. - /// - /// Lambda may contain multiple calls to parallel_for_work_item(...) methods - /// representing the execution on each work-item. Launches NumWorkGroups - /// work-groups of runtime-defined size. - /// - /// \param NumWorkGroups is a range describing the number of work-groups in - /// each dimension. - /// \param KernelFunc is a lambda representing kernel. - template < - typename KernelName, typename KernelType, int Dims, - typename PropertiesT = ext::oneapi::experimental::empty_properties_t> - void parallel_for_work_group_lambda_impl(range NumWorkGroups, - PropertiesT Props, - const KernelType &KernelFunc) { - using LambdaArgType = - sycl::detail::lambda_arg_type>; - wrap_kernel(KernelFunc, nullptr /*Kernel*/, - Props, NumWorkGroups); - } - - /// Hierarchical kernel invocation method of a kernel defined as a lambda - /// encoding the body of each work-group to launch. - /// - /// Lambda may contain multiple calls to parallel_for_work_item(...) methods - /// representing the execution on each work-item. Launches NumWorkGroups - /// work-groups of WorkGroupSize size. - /// - /// \param NumWorkGroups is a range describing the number of work-groups in - /// each dimension. - /// \param WorkGroupSize is a range describing the size of work-groups in - /// each dimension. - /// \param KernelFunc is a lambda representing kernel. - template < - typename KernelName, typename KernelType, int Dims, - typename PropertiesT = ext::oneapi::experimental::empty_properties_t> - void parallel_for_work_group_lambda_impl(range NumWorkGroups, - range WorkGroupSize, - PropertiesT Props, - const KernelType &KernelFunc) { - using LambdaArgType = - sycl::detail::lambda_arg_type>; - nd_range ExecRange = - nd_range(NumWorkGroups * WorkGroupSize, WorkGroupSize); - wrap_kernel(KernelFunc, nullptr /*Kernel*/, Props, ExecRange); - } - #ifdef SYCL_LANGUAGE_VERSION #ifndef __INTEL_SYCL_USE_INTEGRATION_HEADERS #define __SYCL_KERNEL_ATTR__ [[clang::sycl_kernel_entry_point(KernelName)]] @@ -1598,13 +1547,17 @@ class __SYCL_EXPORT handler { }; template < - WrapAs WrapAsVal, typename KernelName, typename ElementType = void, + WrapAs WrapAsVal, typename KernelName, typename ElementTypeParam = void, int Dims = 1, bool SetNumWorkGroups = false, typename PropertiesT = ext::oneapi::experimental::empty_properties_t, typename KernelType, typename MaybeKernelTy, typename... RangeParams> void wrap_kernel(const KernelType &KernelFunc, MaybeKernelTy &&MaybeKernel, const PropertiesT &Props, [[maybe_unused]] RangeParams &&...params) { + using ElementType = std::conditional_t< + WrapAsVal == WrapAs::parallel_for_work_group, + sycl::detail::lambda_arg_type>, + ElementTypeParam>; // TODO: Properties may change the kernel function, so in order to avoid // conflicts they should be included in the name. using NameT = @@ -1959,9 +1912,10 @@ class __SYCL_EXPORT handler { int Dims> void parallel_for_work_group(range NumWorkGroups, const KernelType &KernelFunc) { - parallel_for_work_group_lambda_impl( - NumWorkGroups, ext::oneapi::experimental::empty_properties_t{}, - KernelFunc); + wrap_kernel(KernelFunc, nullptr /*Kernel*/, + {} /*Props*/, NumWorkGroups); } /// Hierarchical kernel invocation method of a kernel defined as a lambda @@ -1981,9 +1935,10 @@ class __SYCL_EXPORT handler { void parallel_for_work_group(range NumWorkGroups, range WorkGroupSize, const KernelType &KernelFunc) { - parallel_for_work_group_lambda_impl( - NumWorkGroups, WorkGroupSize, - ext::oneapi::experimental::empty_properties_t{}, KernelFunc); + wrap_kernel( + KernelFunc, nullptr /*Kernel*/, {} /*Props*/, + nd_range{NumWorkGroups * WorkGroupSize, WorkGroupSize}); } /// Invokes a SYCL kernel. @@ -2395,9 +2350,10 @@ class __SYCL_EXPORT handler { "member function instead.") void parallel_for_work_group(range NumWorkGroups, PropertiesT Props, const KernelType &KernelFunc) { - parallel_for_work_group_lambda_impl(NumWorkGroups, Props, - KernelFunc); + wrap_kernel(KernelFunc, nullptr /*Kernel*/, + Props, NumWorkGroups); } template NumWorkGroups, range WorkGroupSize, PropertiesT Props, const KernelType &KernelFunc) { - parallel_for_work_group_lambda_impl( - NumWorkGroups, WorkGroupSize, Props, KernelFunc); + wrap_kernel( + KernelFunc, nullptr /*Kernel*/, Props, + nd_range{NumWorkGroups * WorkGroupSize, WorkGroupSize}); } // Explicit copy operations API