Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 21 additions & 64 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1337,57 +1337,6 @@ class __SYCL_EXPORT handler {
#endif
}

/// Implementation helper for hierarchical (work-group level) kernel launches
/// where only the number of work-groups is supplied by the caller.
///
/// The kernel is a lambda encoding the body of each work-group; it may
/// contain multiple calls to parallel_for_work_item(...) describing the
/// execution on each work-item. Launches NumWorkGroups work-groups of
/// runtime-defined size.
///
/// \param NumWorkGroups is a range describing the number of work-groups in
/// each dimension.
/// \param Props is the property object forwarded unchanged to wrap_kernel.
/// \param KernelFunc is a lambda representing the kernel.
template <
    typename KernelName, typename KernelType, int Dims,
    typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
void parallel_for_work_group_lambda_impl(range<Dims> NumWorkGroups,
                                         PropertiesT Props,
                                         const KernelType &KernelFunc) {
  // Element type handed to wrap_kernel, resolved from the lambda's signature
  // by lambda_arg_type with group<Dims> as its second argument.
  using GroupArgT = sycl::detail::lambda_arg_type<KernelType, group<Dims>>;

  // Only a work-group count is known here, so wrap_kernel is instantiated
  // with its SetNumWorkGroups flag enabled.
  constexpr bool SetNumWorkGroups = true;

  // No kernel object is supplied on this path, hence the nullptr argument.
  wrap_kernel<WrapAs::parallel_for_work_group, KernelName, GroupArgT, Dims,
              SetNumWorkGroups>(KernelFunc, nullptr /*Kernel*/, Props,
                                NumWorkGroups);
}

/// Implementation helper for hierarchical (work-group level) kernel launches
/// where both the number of work-groups and the work-group size are supplied
/// by the caller.
///
/// The kernel is a lambda encoding the body of each work-group; it may
/// contain multiple calls to parallel_for_work_item(...) describing the
/// execution on each work-item. Launches NumWorkGroups work-groups of
/// WorkGroupSize size.
///
/// \param NumWorkGroups is a range describing the number of work-groups in
/// each dimension.
/// \param WorkGroupSize is a range describing the size of work-groups in
/// each dimension.
/// \param Props is the property object forwarded unchanged to wrap_kernel.
/// \param KernelFunc is a lambda representing the kernel.
template <
    typename KernelName, typename KernelType, int Dims,
    typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
void parallel_for_work_group_lambda_impl(range<Dims> NumWorkGroups,
                                         range<Dims> WorkGroupSize,
                                         PropertiesT Props,
                                         const KernelType &KernelFunc) {
  // Element type handed to wrap_kernel, resolved from the lambda's signature
  // by lambda_arg_type with group<Dims> as its second argument.
  using GroupArgT = sycl::detail::lambda_arg_type<KernelType, group<Dims>>;

  // The nd_range is constructed from the product of the work-group count and
  // the work-group size, paired with the work-group size itself.
  nd_range<Dims> LaunchRange(NumWorkGroups * WorkGroupSize, WorkGroupSize);

  // No kernel object is supplied on this path, hence the nullptr argument.
  wrap_kernel<WrapAs::parallel_for_work_group, KernelName, GroupArgT, Dims>(
      KernelFunc, nullptr /*Kernel*/, Props, LaunchRange);
}

#ifdef SYCL_LANGUAGE_VERSION
#ifndef __INTEL_SYCL_USE_INTEGRATION_HEADERS
#define __SYCL_KERNEL_ATTR__ [[clang::sycl_kernel_entry_point(KernelName)]]
Expand Down Expand Up @@ -1598,13 +1547,17 @@ class __SYCL_EXPORT handler {
};

template <
WrapAs WrapAsVal, typename KernelName, typename ElementType = void,
WrapAs WrapAsVal, typename KernelName, typename ElementTypeParam = void,
int Dims = 1, bool SetNumWorkGroups = false,
typename PropertiesT = ext::oneapi::experimental::empty_properties_t,
typename KernelType, typename MaybeKernelTy, typename... RangeParams>
void wrap_kernel(const KernelType &KernelFunc, MaybeKernelTy &&MaybeKernel,
const PropertiesT &Props,
[[maybe_unused]] RangeParams &&...params) {
using ElementType = std::conditional_t<
WrapAsVal == WrapAs::parallel_for_work_group,
sycl::detail::lambda_arg_type<KernelType, group<Dims>>,
ElementTypeParam>;
// TODO: Properties may change the kernel function, so in order to avoid
// conflicts they should be included in the name.
using NameT =
Expand Down Expand Up @@ -1959,9 +1912,10 @@ class __SYCL_EXPORT handler {
int Dims>
void parallel_for_work_group(range<Dims> NumWorkGroups,
const KernelType &KernelFunc) {
parallel_for_work_group_lambda_impl<KernelName>(
NumWorkGroups, ext::oneapi::experimental::empty_properties_t{},
KernelFunc);
wrap_kernel<WrapAs::parallel_for_work_group, KernelName,
void /*auto-detect*/, Dims,
/*SetNumWorkGroups=*/true>(KernelFunc, nullptr /*Kernel*/,
{} /*Props*/, NumWorkGroups);
}

/// Hierarchical kernel invocation method of a kernel defined as a lambda
Expand All @@ -1981,9 +1935,10 @@ class __SYCL_EXPORT handler {
void parallel_for_work_group(range<Dims> NumWorkGroups,
range<Dims> WorkGroupSize,
const KernelType &KernelFunc) {
parallel_for_work_group_lambda_impl<KernelName>(
NumWorkGroups, WorkGroupSize,
ext::oneapi::experimental::empty_properties_t{}, KernelFunc);
wrap_kernel<WrapAs::parallel_for_work_group, KernelName,
void /*auto-detect*/, Dims>(
KernelFunc, nullptr /*Kernel*/, {} /*Props*/,
nd_range<Dims>{NumWorkGroups * WorkGroupSize, WorkGroupSize});
}

/// Invokes a SYCL kernel.
Expand Down Expand Up @@ -2395,9 +2350,10 @@ class __SYCL_EXPORT handler {
"member function instead.")
void parallel_for_work_group(range<Dims> NumWorkGroups, PropertiesT Props,
const KernelType &KernelFunc) {
parallel_for_work_group_lambda_impl<KernelName, KernelType, Dims,
PropertiesT>(NumWorkGroups, Props,
KernelFunc);
wrap_kernel<WrapAs::parallel_for_work_group, KernelName,
void /*auto-detect*/, Dims,
/*SetNumWorkGroups=*/true>(KernelFunc, nullptr /*Kernel*/,
Props, NumWorkGroups);
}

template <typename KernelName = detail::auto_name, typename KernelType,
Expand All @@ -2409,9 +2365,10 @@ class __SYCL_EXPORT handler {
void parallel_for_work_group(range<Dims> NumWorkGroups,
range<Dims> WorkGroupSize, PropertiesT Props,
const KernelType &KernelFunc) {
parallel_for_work_group_lambda_impl<KernelName, KernelType, Dims,
PropertiesT>(
NumWorkGroups, WorkGroupSize, Props, KernelFunc);
wrap_kernel<WrapAs::parallel_for_work_group, KernelName,
void /*auto-detect*/, Dims>(
KernelFunc, nullptr /*Kernel*/, Props,
nd_range<Dims>{NumWorkGroups * WorkGroupSize, WorkGroupSize});
}

// Explicit copy operations API
Expand Down