diff --git a/sycl/doc/syclcompat/README.md b/sycl/doc/syclcompat/README.md index df8d453a41d57..c3a263d2d072a 100644 --- a/sycl/doc/syclcompat/README.md +++ b/sycl/doc/syclcompat/README.md @@ -1725,7 +1725,51 @@ second operand, respectively. These three APIs return a single 32-bit value with the accumulated result, which is unsigned if both operands are `uint32_t` and signed otherwise. +Various maths functions are defined to operate on any floating point type. +`syclcompat::is_floating_point_v` extends the standard library's +`std::is_floating_point_v` to include `sycl::half` and, where available, +`sycl::ext::oneapi::bfloat16`. The current version of SYCLcompat also provides +a specialization of `std::common_type` for `sycl::ext::oneapi::bfloat16`, +though this will be moved to the `sycl_ext_oneapi_bfloat16` extension in +future. + +```cpp +namespace std { +template <> struct common_type { + using type = sycl::ext::oneapi::bfloat16; +}; + +template <> +struct common_type { + using type = sycl::ext::oneapi::bfloat16; +}; + +template struct common_type { + using type = sycl::ext::oneapi::bfloat16; +}; + +template struct common_type { + using type = sycl::ext::oneapi::bfloat16; +}; +} // namespace std +``` + ```cpp +namespace syclcompat{ + +// Trait for extended floating point definition +template +struct is_floating_point : std::is_floating_point{}; + +template <> struct is_floating_point : std::true_type {}; + +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +template <> struct is_floating_point : std::true_type {}; +#endif +template + +inline constexpr bool is_floating_point_v = is_floating_point::value; + inline unsigned int funnelshift_l(unsigned int low, unsigned int high, unsigned int shift); @@ -1752,11 +1796,9 @@ inline std::enable_if_t isnan(const ValueT a); // cbrt function wrapper. 
template inline std::enable_if_t || - std::is_same_v, + std::is_same_v, ValueT> -cbrt(ValueT val) { - return sycl::cbrt(static_cast(val)); -} +cbrt(ValueT val); // For floating-point types, `float` or `double` arguments are acceptable. // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or @@ -1794,6 +1836,10 @@ template inline sycl::vec, 2> fmax_nan(const sycl::vec a, const sycl::vec b); +template +inline sycl::marray, 2> +fmax_nan(const sycl::marray a, const sycl::marray b); + // Performs 2 elements comparison and returns the smaller one. If either of // inputs is NaN, then return NaN. template @@ -1803,6 +1849,10 @@ template inline sycl::vec, 2> fmin_nan(const sycl::vec a, const sycl::vec b); +template +inline sycl::marray, 2> +fmin_nan(const sycl::marray a, const sycl::marray b); + inline float pow(const float a, const int b) { return sycl::pown(a, b); } inline double pow(const double a, const int b) { return sycl::pown(a, b); } @@ -1863,14 +1913,13 @@ unordered_compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op); template -inline unsigned compare_mask(const sycl::vec a, - const sycl::vec b, - const BinaryOperation binary_op); +inline std::enable_if_t +compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op); template -inline unsigned unordered_compare_mask(const sycl::vec a, - const sycl::vec b, - const BinaryOperation binary_op); +inline std::enable_if_t +unordered_compare_mask(const ValueT a, const ValueT b, + const BinaryOperation binary_op); template inline T vectorized_max(T a, T b); @@ -1924,6 +1973,7 @@ inline dot_product_acc_t dp2a_hi(T1 a, T2 b, template inline dot_product_acc_t dp4a(T1 a, T2 b, dot_product_acc_t c); +} // namespace syclcompat ``` `vectorized_binary` computes the `BinaryOperation` for two operands, diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index 785d95f6f2404..a3ee2b2085788 100644 --- a/sycl/include/syclcompat/math.hpp +++ 
b/sycl/include/syclcompat/math.hpp @@ -31,12 +31,19 @@ #pragma once +#include +#include + +// TODO(syclcompat-lib-reviewers): this should not be required #ifndef SYCL_EXT_ONEAPI_COMPLEX #define SYCL_EXT_ONEAPI_COMPLEX #endif +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS #include +#endif #include +#include namespace syclcompat { namespace detail { @@ -46,18 +53,25 @@ namespace complex_namespace = sycl::ext::oneapi::experimental; template using complex_type = detail::complex_namespace::complex; +template +constexpr bool is_int32_type = std::is_same_v, int32_t> || + std::is_same_v, uint32_t>; + +// Helper constexpr bool to avoid ugly macros where possible +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +constexpr bool support_bfloat16_math = true; +#else +constexpr bool support_bfloat16_math = false; +#endif + template inline ValueT clamp(ValueT val, ValueT min_val, ValueT max_val) { return sycl::clamp(val, min_val, max_val); } - -template -constexpr bool is_int32_type = std::is_same_v, int32_t> || - std::is_same_v, uint32_t>; - #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS -// TODO: Follow the process to add this to the extension. If added, -// remove this functionality from the header. +// TODO(syclcompat-lib-reviewers): Follow the process to add this (& other math +// fns) to the bfloat16 math function extension. If added, remove this +// functionality from the header. 
template <> inline sycl::ext::oneapi::bfloat16 clamp(sycl::ext::oneapi::bfloat16 val, sycl::ext::oneapi::bfloat16 min_val, @@ -68,6 +82,28 @@ inline sycl::ext::oneapi::bfloat16 clamp(sycl::ext::oneapi::bfloat16 val, return max_val; return val; } + +template +inline std::enable_if_t, + sycl::vec> +clamp(sycl::vec val, sycl::vec min_val, + sycl::vec max_val) { + return [&val, &min_val, &max_val](std::integer_sequence) { + return sycl::vec{ + clamp(val[I], min_val[I], max_val[I])...}; + }(std::make_integer_sequence{}); +} + +template +inline std::enable_if_t, + sycl::marray> +clamp(sycl::marray val, sycl::marray min_val, + sycl::marray max_val) { + return [&val, &min_val, &max_val](std::index_sequence) { + return sycl::marray{ + clamp(val[I], min_val[I], max_val[I])...}; + }(std::make_index_sequence{}); +} #endif template @@ -218,13 +254,13 @@ inline constexpr RetT extend_vbinary4(AT a, BT b, RetT c, } template inline bool isnan(const ValueT a) { - return sycl::isnan(a); -} -#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS -inline bool isnan(const sycl::ext::oneapi::bfloat16 a) { - return sycl::ext::oneapi::experimental::isnan(a); + if constexpr (std::is_same_v) { + static_assert(detail::support_bfloat16_math); + return sycl::ext::oneapi::experimental::isnan(a); + } else { + return sycl::isnan(a); + } } -#endif // FIXME(syclcompat-lib-reviewers): move bfe outside detail once perf is // improved & semantics understood @@ -543,9 +579,8 @@ unordered_compare_both(const ValueT a, const ValueT b, /// \param [in] binary_op functor that implements the binary operation /// \returns the comparison result template -inline unsigned compare_mask(const sycl::vec a, - const sycl::vec b, - const BinaryOperation binary_op) { +inline std::enable_if_t +compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op) { // Since compare returns 0 or 1, -compare will be 0x00000000 or 0xFFFFFFFF return ((-compare(a[0], b[0], binary_op)) << 16) | ((-compare(a[1], b[1], binary_op)) 
& 0xFFFF); @@ -559,9 +594,9 @@ inline unsigned compare_mask(const sycl::vec a, /// \param [in] binary_op functor that implements the binary operation /// \returns the comparison result template -inline unsigned unordered_compare_mask(const sycl::vec a, - const sycl::vec b, - const BinaryOperation binary_op) { +inline std::enable_if_t +unordered_compare_mask(const ValueT a, const ValueT b, + const BinaryOperation binary_op) { return ((-unordered_compare(a[0], b[0], binary_op)) << 16) | ((-unordered_compare(a[1], b[1], binary_op)) & 0xFFFF); } @@ -687,7 +722,7 @@ inline std::enable_if_t isnan(const ValueT a) { /// cbrt function wrapper. template inline std::enable_if_t || - std::is_same_v, + std::is_same_v, ValueT> cbrt(ValueT val) { return sycl::cbrt(static_cast(val)); @@ -697,7 +732,7 @@ cbrt(ValueT val) { // For floating-point types, `float` or `double` arguments are acceptable. // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or // `std::int64_t` type arguments are acceptable. -// sycl::half supported as well. +// sycl::half supported as well, and sycl::ext::oneapi::bfloat16 if available. 
template inline std::enable_if_t && std::is_integral_v, @@ -706,15 +741,23 @@ min(ValueT a, ValueU b) { return sycl::min(static_cast>(a), static_cast>(b)); } + template -inline std::enable_if_t && - std::is_floating_point_v, +inline std::enable_if_t && + syclcompat::is_floating_point_v, std::common_type_t> min(ValueT a, ValueU b) { - return sycl::fmin(static_cast>(a), - static_cast>(b)); + if constexpr (std::is_same_v, + sycl::ext::oneapi::bfloat16>) { + static_assert(detail::support_bfloat16_math); + return sycl::ext::oneapi::experimental::fmin( + static_cast>(a), + static_cast>(b)); + } else { + return sycl::fmin(static_cast>(a), + static_cast>(b)); + } } -inline sycl::half min(sycl::half a, sycl::half b) { return sycl::fmin(a, b); } template inline std::enable_if_t && @@ -725,14 +768,21 @@ max(ValueT a, ValueU b) { static_cast>(b)); } template -inline std::enable_if_t && - std::is_floating_point_v, +inline std::enable_if_t && + syclcompat::is_floating_point_v, std::common_type_t> max(ValueT a, ValueU b) { - return sycl::fmax(static_cast>(a), - static_cast>(b)); + if constexpr (std::is_same_v, + sycl::ext::oneapi::bfloat16>) { + static_assert(detail::support_bfloat16_math); + return sycl::ext::oneapi::experimental::fmax( + static_cast>(a), + static_cast>(b)); + } else { + return sycl::fmax(static_cast>(a), + static_cast>(b)); + } } -inline sycl::half max(sycl::half a, sycl::half b) { return sycl::fmax(a, b); } /// Performs 2 elements comparison and returns the bigger one. If either of /// inputs is NaN, then return NaN. 
@@ -744,15 +794,21 @@ inline std::common_type_t fmax_nan(const ValueT a, const ValueU b) { if (detail::isnan(a) || detail::isnan(b)) return NAN; - return sycl::fmax(static_cast>(a), - static_cast>(b)); + return syclcompat::max(a, b); } + template inline sycl::vec, 2> fmax_nan(const sycl::vec a, const sycl::vec b) { return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; } +template +inline sycl::marray, 2> +fmax_nan(const sycl::marray a, const sycl::marray b) { + return {fmax_nan(a[0], b[0]), fmax_nan(a[1], b[1])}; +} + /// Performs 2 elements comparison and returns the smaller one. If either of /// inputs is NaN, then return NaN. /// \param [in] a The first value @@ -763,15 +819,21 @@ inline std::common_type_t fmin_nan(const ValueT a, const ValueU b) { if (detail::isnan(a) || detail::isnan(b)) return NAN; - return sycl::fmin(static_cast>(a), - static_cast>(b)); + return syclcompat::min(a,b); } + template inline sycl::vec, 2> fmin_nan(const sycl::vec a, const sycl::vec b) { return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; } +template +inline sycl::marray, 2> +fmin_nan(const sycl::marray a, const sycl::marray b) { + return {fmin_nan(a[0], b[0]), fmin_nan(a[1], b[1])}; +} + // pow functions overload. inline float pow(const float a, const int b) { return sycl::pown(a, b); } inline double pow(const double a, const int b) { return sycl::pown(a, b); } @@ -781,10 +843,10 @@ inline typename std::enable_if_t, ValueT> pow(const ValueT a, const ValueU b) { return sycl::pow(a, static_cast(b)); } - -// TODO: calling pow with non-floating point values is currently defaulting to -// double, which fails on devices without aspect::fp64. This has to be properly -// documented, and maybe changed to support all devices. +// TODO(syclcompat-lib-reviewers) calling pow with non-floating point values +// is currently defaulting to double, which fails on devices without +// aspect::fp64. This has to be properly documented, and maybe changed to +// support all devices. 
template inline typename std::enable_if_t, double> pow(const ValueT a, const ValueU b) { @@ -795,24 +857,20 @@ pow(const ValueT a, const ValueU b) { /// \param [in] a The input value /// \returns the relu saturation result template -inline std::enable_if_t || - std::is_same_v, - ValueT> +inline std::enable_if_t, ValueT> relu(const ValueT a) { if (!detail::isnan(a) && a < ValueT(0)) return ValueT(0); return a; } template -inline std::enable_if_t || - std::is_same_v, +inline std::enable_if_t, sycl::vec> relu(const sycl::vec a) { return {relu(a[0]), relu(a[1])}; } template -inline std::enable_if_t || - std::is_same_v, +inline std::enable_if_t, sycl::marray> relu(const sycl::marray a) { return {relu(a[0]), relu(a[1])}; diff --git a/sycl/include/syclcompat/traits.hpp b/sycl/include/syclcompat/traits.hpp index 2f389ccf79484..7ed4d765251bc 100644 --- a/sycl/include/syclcompat/traits.hpp +++ b/sycl/include/syclcompat/traits.hpp @@ -22,6 +22,10 @@ #pragma once +#include +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +#include +#endif #include #include #include @@ -250,4 +254,40 @@ using are_all_props = std::conjunction< } // namespace experimental::detail +// Trait for extended floating point definition +template +struct is_floating_point : std::is_floating_point{}; + +template <> struct is_floating_point : std::true_type {}; + +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +template <> struct is_floating_point : std::true_type {}; +#endif + +template +inline constexpr bool is_floating_point_v = is_floating_point::value; + } // namespace syclcompat + +// Specialize std::common_type for bfloat16 +// Semantics here match bfloat16.hpp operator overloads (all mixed type math +// ops return bfloat16) +// TODO(syclcompat-lib-reviewers) Move this to bfloat extension +namespace std { +template <> struct common_type { + using type = sycl::ext::oneapi::bfloat16; +}; + +template <> +struct common_type { + using type = sycl::ext::oneapi::bfloat16; +}; + +template struct common_type { 
+ using type = sycl::ext::oneapi::bfloat16; +}; + +template struct common_type { + using type = sycl::ext::oneapi::bfloat16; +}; +} // namespace std diff --git a/sycl/test-e2e/syclcompat/common.hpp b/sycl/test-e2e/syclcompat/common.hpp index 368089e89e85a..ff840c98209bd 100644 --- a/sycl/test-e2e/syclcompat/common.hpp +++ b/sycl/test-e2e/syclcompat/common.hpp @@ -22,6 +22,10 @@ #pragma once +#include +#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS +#include +#endif #include #include @@ -44,8 +48,42 @@ template void instantiate_all_types(Func &&f) { f(); \ }); +#define INSTANTIATE_ALL_CONTAINER_TYPES(tuple, container, f) \ + instantiate_all_types([](auto index) { \ + using T = std::tuple_element_t; \ + f(); \ + }); + using value_type_list = - std::tuple; + std::tuple; + +using fp_type_list_no_bfloat16 = std::tuple; + +using fp_type_list = std::tuple; -using fp_type_list = std::tuple; +using marray_type_list = + std::tuple; +using vec_type_list = std::tuple; diff --git a/sycl/test-e2e/syclcompat/launch/launch_policy_lmem.cpp b/sycl/test-e2e/syclcompat/launch/launch_policy_lmem.cpp index a22d54474d9ed..41f9a8cbee747 100644 --- a/sycl/test-e2e/syclcompat/launch/launch_policy_lmem.cpp +++ b/sycl/test-e2e/syclcompat/launch/launch_policy_lmem.cpp @@ -58,14 +58,19 @@ void dynamic_local_mem_typed_kernel(T *data, char *local_mem) { constexpr size_t num_elements = memsize / sizeof(T); T *typed_local_mem = reinterpret_cast(local_mem); - const int id = - sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_global_linear_id(); - if (id < num_elements) { - typed_local_mem[id] = static_cast(id); - } - syclcompat::wg_barrier(); - if (id < num_elements) { - data[id] = typed_local_mem[num_elements - id - 1]; + const int local_id = + sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_linear_id(); + const int group_id = + sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_group_linear_id(); + // Only operate in first work-group + if (group_id == 0) { + if (local_id < 
num_elements) { + typed_local_mem[local_id] = static_cast(local_id); + } + syclcompat::wg_barrier(); + if (local_id < num_elements) { + data[local_id] = typed_local_mem[num_elements - local_id - 1]; + } } }; diff --git a/sycl/test-e2e/syclcompat/math/math_compare.cpp b/sycl/test-e2e/syclcompat/math/math_compare.cpp index c42b22b199888..0f77160a564e7 100644 --- a/sycl/test-e2e/syclcompat/math/math_compare.cpp +++ b/sycl/test-e2e/syclcompat/math/math_compare.cpp @@ -56,7 +56,7 @@ template void test_compare() { constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - constexpr ValueT op1 = static_cast(1.0); + const ValueT op1 = static_cast(1.0); ValueT op2 = sycl::nan(static_cast(0)); // 1.0 == 1.0 -> True @@ -96,13 +96,14 @@ void compare_not_equal_vec_kernel(Container *a, Container *b, Container *r) { *r = syclcompat::compare(*a, *b, std::not_equal_to<>()); } -template void test_compare_vec() { +template