diff --git a/sycl/include/sycl/ext/oneapi/experimental/reduction_properties.hpp b/sycl/include/sycl/ext/oneapi/experimental/reduction_properties.hpp index a96378c522f82..b4f2916a55bf1 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/reduction_properties.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/reduction_properties.hpp @@ -32,6 +32,20 @@ struct initialize_to_identity_key }; inline constexpr initialize_to_identity_key::value_t initialize_to_identity; +namespace detail { +struct reduction_property_check_anchor {}; +} // namespace detail + +template <> +struct is_property_key_of + : std::true_type {}; + +template <> +struct is_property_key_of + : std::true_type {}; + } // namespace experimental } // namespace oneapi } // namespace ext @@ -83,60 +97,88 @@ template struct IsDeterministicOperator> : std::true_type {}; +template +inline constexpr bool is_valid_reduction_prop_list = + ext::oneapi::experimental::detail::all_are_properties_of_v< + ext::oneapi::experimental::detail::reduction_property_check_anchor, + PropertyList>; + +template +auto convert_reduction_properties(BinaryOperation combiner, + PropertyList properties, Args &&...args) { + if constexpr (is_valid_reduction_prop_list) { + auto WrappedOp = WrapOp(combiner, properties); + auto RuntimeProps = GetReductionPropertyList(properties); + return sycl::reduction(std::forward(args)..., WrappedOp, + RuntimeProps); + } else { + // Invalid, will be disabled by SFINAE at the caller side. Make sure no hard + // error is emitted from here. + } +} } // namespace detail template auto reduction(BufferT vars, handler &cgh, BinaryOperation combiner, - PropertyList properties) { + PropertyList properties) + -> std::enable_if_t, + decltype(detail::convert_reduction_properties( + combiner, properties, vars, cgh))> { detail::CheckReductionIdentity( properties); - auto WrappedOp = detail::WrapOp(combiner, properties); - auto RuntimeProps = detail::GetReductionPropertyList(properties); - return reduction(vars, cgh, WrappedOp, RuntimeProps); + return detail::convert_reduction_properties(combiner, properties, vars, cgh); } template -auto reduction(T *var, BinaryOperation combiner, PropertyList properties) { +auto reduction(T *var, BinaryOperation combiner, PropertyList properties) + -> std::enable_if_t, + decltype(detail::convert_reduction_properties( + combiner, properties, var))> { detail::CheckReductionIdentity(properties); - auto WrappedOp = detail::WrapOp(combiner, properties); - auto RuntimeProps = detail::GetReductionPropertyList(properties); - return reduction(var, WrappedOp, RuntimeProps); + return detail::convert_reduction_properties(combiner, properties, var); } template auto reduction(span vars, BinaryOperation combiner, - PropertyList properties) { + PropertyList properties) + -> std::enable_if_t, + decltype(detail::convert_reduction_properties( + combiner, properties, vars))> { detail::CheckReductionIdentity(properties); - auto WrappedOp = detail::WrapOp(combiner, properties); - auto RuntimeProps = detail::GetReductionPropertyList(properties); - return reduction(vars, WrappedOp, RuntimeProps); + return detail::convert_reduction_properties(combiner, properties, vars); } template auto reduction(BufferT vars, handler &cgh, const typename BufferT::value_type &identity, - BinaryOperation combiner, PropertyList properties) { - auto WrappedOp = detail::WrapOp(combiner, properties); - auto RuntimeProps = detail::GetReductionPropertyList(properties); - return reduction(vars, cgh, identity, WrappedOp, RuntimeProps); + BinaryOperation combiner, PropertyList properties) + -> std::enable_if_t, + decltype(detail::convert_reduction_properties( + combiner, properties, vars, cgh, identity))> { + return detail::convert_reduction_properties(combiner, properties, vars, cgh, + identity); } template auto reduction(T *var, const T &identity, BinaryOperation combiner, - PropertyList properties) { - auto WrappedOp = detail::WrapOp(combiner, properties); - auto RuntimeProps = detail::GetReductionPropertyList(properties); - return reduction(var, identity, WrappedOp, RuntimeProps); + PropertyList properties) + -> std::enable_if_t, + decltype(detail::convert_reduction_properties( + combiner, properties, var, identity))> { + return detail::convert_reduction_properties(combiner, properties, var, + identity); } template auto reduction(span vars, const T &identity, - BinaryOperation combiner, PropertyList properties) { - auto WrappedOp = detail::WrapOp(combiner, properties); - auto RuntimeProps = detail::GetReductionPropertyList(properties); - return reduction(vars, identity, WrappedOp, RuntimeProps); + BinaryOperation combiner, PropertyList properties) + -> std::enable_if_t, + decltype(detail::convert_reduction_properties( + combiner, properties, vars, identity))> { + return detail::convert_reduction_properties(combiner, properties, vars, + identity); } } // namespace _V1 diff --git a/sycl/test/extensions/properties/properties_reduction.cpp b/sycl/test/extensions/properties/properties_reduction.cpp new file mode 100644 index 0000000000000..eb94010a1e523 --- /dev/null +++ b/sycl/test/extensions/properties/properties_reduction.cpp @@ -0,0 +1,28 @@ +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -fsyntax-only -Xclang -verify -Xclang -verify-ignore-unexpected=note %s + +#include + +int main() { + int *r = nullptr; + // Must not use `sycl_ext_oneapi_reduction_properties`'s overloads: + std::ignore = + sycl::reduction(r, sycl::plus{}, + sycl::property::reduction::initialize_to_identity{}); + + namespace sycl_exp = sycl::ext::oneapi::experimental; + std::ignore = + sycl::reduction(r, sycl::plus{}, + sycl_exp::properties(sycl_exp::initialize_to_identity)); + + // Not a property list: + // expected-error@+2 {{no matching function for call to 'reduction'}} + std::ignore = + sycl::reduction(r, sycl::plus{}, sycl_exp::initialize_to_identity); + + // Not a reduction property: + // expected-error@+2 {{no matching function for call to 'reduction'}} + std::ignore = + sycl::reduction(r, sycl::plus{}, + sycl_exp::properties(sycl_exp::initialize_to_identity, + sycl_exp::full_group)); +}