diff --git a/sycl/doc/syclcompat/README.md b/sycl/doc/syclcompat/README.md index 492244aef0063..2bd8aca1cd328 100644 --- a/sycl/doc/syclcompat/README.md +++ b/sycl/doc/syclcompat/README.md @@ -1874,17 +1874,11 @@ template inline typename std::enable_if_t, double> pow(const ValueT a, const ValueU b); -template -inline std::enable_if_t || - std::is_same_v, - ValueT> -relu(const ValueT a); +template inline ValueT relu(const ValueT a); -template -inline std::enable_if_t || - std::is_same_v, - sycl::vec> -relu(const sycl::vec a); +template +inline sycl::vec +relu(const sycl::vec a); template inline std::enable_if_t || @@ -1987,9 +1981,12 @@ inline dot_product_acc_t dp4a(T1 a, T2 b, `vectorized_binary` computes the `BinaryOperation` for two operands, with each value treated as a vector type. `vectorized_unary` offers the same -interface for operations with a single operand. +interface for operations with a single operand. `vectorized_ternary` offers the +interface for three operands with two `BinaryOperation`. The implemented `BinaryOperation`s are `abs_diff`, `add_sat`, `rhadd`, `hadd`, `maximum`, `minimum`, and `sub_sat`. +And the `vectorized_with_pred` offers the `BinaryOperation` for two operands, +meanwihle provides the pred of high/low halfword operation. ```cpp namespace syclcompat { @@ -2004,7 +2001,19 @@ struct abs { template inline unsigned vectorized_binary(unsigned a, unsigned b, - const BinaryOperation binary_op); + const BinaryOperation binary_op, + bool need_relu = false); + +template +inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, + const BinaryOperation1 binary_op1, + const BinaryOperation2 binary_op2, + bool need_relu = false); + +template +inline unsigned vectorized_with_pred(unsigned a, unsigned b, + const BinaryOperation binary_op, + bool *pred_hi, bool *pred_lo); // A sycl::abs_diff wrapper functor. struct abs_diff { @@ -2030,11 +2039,15 @@ struct hadd { struct maximum { template auto operator()(const ValueT x, const ValueT y) const; + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const; }; // A sycl::min wrapper functor. struct minimum { template auto operator()(const ValueT x, const ValueT y) const; + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const; }; // A sycl::sub_sat wrapper functor. struct sub_sat { diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index a3ee2b2085788..b0b8a93d6697c 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -856,23 +856,24 @@ pow(const ValueT a, const ValueU b) { /// Performs relu saturation. /// \param [in] a The input value /// \returns the relu saturation result -template -inline std::enable_if_t, ValueT> -relu(const ValueT a) { - if (!detail::isnan(a) && a < ValueT(0)) +template inline ValueT relu(const ValueT a) { + if constexpr (syclcompat::is_floating_point_v) + if (detail::isnan(a)) + return a; + if (a < ValueT(0)) return ValueT(0); return a; } -template -inline std::enable_if_t, - sycl::vec> -relu(const sycl::vec a) { - return {relu(a[0]), relu(a[1])}; +template +inline sycl::vec +relu(const sycl::vec a) { + sycl::vec ret; + for (int i = 0; i < NumElements; ++i) + ret[i] = relu(a[i]); + return ret; } template -inline std::enable_if_t, - sycl::marray> -relu(const sycl::marray a) { +inline sycl::marray relu(const sycl::marray a) { return {relu(a[0]), relu(a[1])}; } @@ -990,6 +991,10 @@ struct maximum { auto operator()(const ValueT x, const ValueT y) const { return sycl::max(x, y); } + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const { + return (x >= y) ? ((*pred = true), x) : ((*pred = false), y); + } }; /// A sycl::min wrapper functors. @@ -998,6 +1003,10 @@ struct minimum { auto operator()(const ValueT x, const ValueT y) const { return sycl::min(x, y); } + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const { + return (x <= y) ? ((*pred = true), x) : ((*pred = false), y); + } }; /// A sycl::sub_sat wrapper functors. @@ -1037,19 +1046,76 @@ struct average { /// \tparam [in] BinaryOperation The binary operation class /// \param [in] a The first value /// \param [in] b The second value +/// \param [in] binary_op The operation to do with the two values +/// \param [in] need_relu Whether the result need relu saturation /// \returns The vectorized binary operation value of the two values template inline unsigned vectorized_binary(unsigned a, unsigned b, - const BinaryOperation binary_op) { + const BinaryOperation binary_op, + bool need_relu = false) { sycl::vec v0{a}, v1{b}; auto v2 = v0.as(); auto v3 = v1.as(); auto v4 = detail::vectorized_binary()(v2, v3, binary_op); + if (need_relu) + v4 = relu(v4); v0 = v4.template as>(); return v0; } +/// Compute two vectorized binary operation value with pred for three values, +/// with each value treated as a 2 \p T type elements vector type. +/// +/// \tparam [in] VecT The type of the vector +/// \tparam [in] BinaryOperation1 The first binary operation class +/// \tparam [in] BinaryOperation2 The second binary operation class +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] binary_op1 The first operation to do with the first two values +/// \param [in] binary_op2 The second operation to do with the third values +/// \param [in] need_relu Whether the result need relu saturation +/// \returns The two vectorized binary operation value of the three values +template +inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, + const BinaryOperation1 binary_op1, + const BinaryOperation2 binary_op2, + bool need_relu = false) { + const auto v1 = sycl::vec(a).as(); + const auto v2 = sycl::vec(b).as(); + const auto v3 = sycl::vec(c).as(); + auto v4 = + detail::vectorized_binary()(v1, v2, binary_op1); + v4 = detail::vectorized_binary()(v4, v3, binary_op2); + if (need_relu) + v4 = relu(v4); + return v4.template as>(); +} + +/// Compute vectorized binary operation value with pred for two values, with +/// each value treated as a 2 \p T type elements vector type. +/// +/// \tparam [in] VecT The type of the vector +/// \tparam [in] BinaryOperation The binary operation class +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op The operation with pred to do with the two values +/// \param [out] pred_hi The pred pointer that pass into high halfword operation +/// \param [out] pred_lo The pred pointer that pass into low halfword operation +/// \returns The vectorized binary operation value of the two values +template +inline unsigned vectorized_binary_with_pred(unsigned a, unsigned b, + const BinaryOperation binary_op, + bool *pred_hi, bool *pred_lo) { + auto v1 = sycl::vec(a).as(); + auto v2 = sycl::vec(b).as(); + VecT ret; + ret[0] = binary_op(v1[0], v2[0], pred_lo); + ret[1] = binary_op(v1[1], v2[1], pred_hi); + return ret.template as>(); +} + template using dot_product_acc_t = std::conditional_t && std::is_unsigned_v, diff --git a/sycl/test-e2e/syclcompat/math/math_fixt.hpp b/sycl/test-e2e/syclcompat/math/math_fixt.hpp index cacd6ea1fb32c..4647142da6c61 100644 --- a/sycl/test-e2e/syclcompat/math/math_fixt.hpp +++ b/sycl/test-e2e/syclcompat/math/math_fixt.hpp @@ -134,6 +134,8 @@ class BinaryOpTestLauncher : OpTestLauncher { ValueT *op1_; ValueU *op2_; ResultT res_h_, *res_; + bool *res_hi_; + bool *res_lo_; public: BinaryOpTestLauncher(const syclcompat::dim3 &grid, @@ -147,6 +149,8 @@ class BinaryOpTestLauncher : OpTestLauncher { op1_ = syclcompat::malloc(data_size); op2_ = syclcompat::malloc(data_size); res_ = syclcompat::malloc(data_size); + res_hi_ = syclcompat::malloc(1); + res_lo_ = syclcompat::malloc(1); }; virtual ~BinaryOpTestLauncher() { @@ -155,6 +159,8 @@ class BinaryOpTestLauncher : OpTestLauncher { syclcompat::free(op1_); syclcompat::free(op2_); syclcompat::free(res_); + syclcompat::free(res_hi_); + syclcompat::free(res_lo_); } template @@ -169,6 +175,37 @@ class BinaryOpTestLauncher : OpTestLauncher { CHECK(ResultT, res_h_, expected); }; + template + void launch_test(ValueT op1, ValueU op2, ResultT expected, bool need_relu) { + if (skip_) + return; + syclcompat::memcpy(op1_, &op1, data_size_); + syclcompat::memcpy(op2_, &op2, data_size_); + syclcompat::launch(grid_, threads_, op1_, op2_, res_, need_relu); + syclcompat::wait(); + syclcompat::memcpy(&res_h_, res_, data_size_); + + CHECK(ResultT, res_h_, expected); + }; + template + void launch_test(ValueT op1, ValueU op2, ResultT expected, bool expected_hi, + bool expected_lo) { + if (skip_) + return; + syclcompat::memcpy(op1_, &op1, data_size_); + syclcompat::memcpy(op2_, &op2, data_size_); + syclcompat::launch(grid_, threads_, op1_, op2_, res_, res_hi_, + res_lo_); + syclcompat::wait(); + syclcompat::memcpy(&res_h_, res_, data_size_); + bool res_hi_h_, res_lo_h_; + syclcompat::memcpy(&res_hi_h_, res_hi_, 1); + syclcompat::memcpy(&res_lo_h_, res_lo_, 1); + + CHECK(ResultT, res_h_, expected); + assert(res_hi_h_ == expected_hi); + assert(res_lo_h_ == expected_lo); + }; }; template @@ -208,3 +245,54 @@ class UnaryOpTestLauncher : OpTestLauncher { CHECK(ResultT, res_h_, expected); } }; + +// Templated ResultT to support both arithmetic and boolean operators +template > +class TernaryOpTestLauncher : OpTestLauncher { +protected: + ValueT *op1_; + ValueU *op2_; + ValueV *op3_; + ResultT res_h_, *res_; + +public: + TernaryOpTestLauncher(const syclcompat::dim3 &grid, + const syclcompat::dim3 &threads, + const size_t data_size = 1) + : OpTestLauncher{grid, threads, data_size, + should_skip()( + syclcompat::get_current_device())} { + if (skip_) + return; + op1_ = syclcompat::malloc(data_size); + op2_ = syclcompat::malloc(data_size); + op3_ = syclcompat::malloc(data_size); + res_ = syclcompat::malloc(data_size); + }; + + virtual ~TernaryOpTestLauncher() { + if (skip_) + return; + syclcompat::free(op1_); + syclcompat::free(op2_); + syclcompat::free(op3_); + syclcompat::free(res_); + } + + template + void launch_test(ValueT op1, ValueU op2, ValueV op3, ResultT expected, + bool need_relu = false) { + if (skip_) + return; + syclcompat::memcpy(op1_, &op1, data_size_); + syclcompat::memcpy(op2_, &op2, data_size_); + syclcompat::memcpy(op3_, &op3, data_size_); + syclcompat::launch(grid_, threads_, op1_, op2_, op3_, res_, + need_relu); + syclcompat::wait(); + syclcompat::memcpy(&res_h_, res_, data_size_); + + CHECK(ResultT, res_h_, expected); + }; +}; diff --git a/sycl/test-e2e/syclcompat/math/math_ops.cpp b/sycl/test-e2e/syclcompat/math/math_ops.cpp index d52d9c60d8ded..baaa0f7f15210 100644 --- a/sycl/test-e2e/syclcompat/math/math_ops.cpp +++ b/sycl/test-e2e/syclcompat/math/math_ops.cpp @@ -226,8 +226,10 @@ template void test_syclcompat_relu() { UnaryOpTestLauncher(grid, threads) .template launch_test>(op1, res1); - const ValueT op2 = static_cast(-3); - const ValueT res2 = static_cast(0); + const ValueT op2 = std::is_signed_v ? static_cast(-3) + : static_cast(2); + const ValueT res2 = std::is_signed_v ? static_cast(0) + : static_cast(2); UnaryOpTestLauncher(grid, threads) .template launch_test>(op2, res2); @@ -374,7 +376,7 @@ int main() { test_syclcompat_pow(); test_syclcompat_pow(); - INSTANTIATE_ALL_TYPES(fp_type_list, test_syclcompat_relu); + INSTANTIATE_ALL_TYPES(value_type_list, test_syclcompat_relu); INSTANTIATE_ALL_TYPES(fp_type_list_no_bfloat16, test_syclcompat_cbrt); INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_isnan); diff --git a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp index a7870050e2b66..9c57c88ce445b 100644 --- a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp +++ b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp @@ -31,76 +31,177 @@ #include "math_fixt.hpp" template -void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r) { - unsigned ua = static_cast(*a); - unsigned ub = static_cast(*b); - *r = syclcompat::vectorized_binary(ua, ub, BinaryOp()); +void vectorized_binary_kernel(unsigned *a, unsigned *b, unsigned *r, + bool need_relu) { + *r = syclcompat::vectorized_binary(*a, *b, BinaryOp(), need_relu); } template -void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected) { +void test_vectorized_binary(unsigned op1, unsigned op2, unsigned expected, + bool need_relu = false) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - BinaryOpTestLauncher(grid, threads) + BinaryOpTestLauncher(grid, threads) .template launch_test>( - op1, op2, expected); + op1, op2, expected, need_relu); } template -void vectorized_unary_kernel(ValueT *a, unsigned *r) { - unsigned ua = static_cast(*a); - *r = syclcompat::vectorized_unary(ua, UnaryOp()); +void vectorized_unary_kernel(unsigned *a, unsigned *r) { + *r = syclcompat::vectorized_unary(*a, UnaryOp()); } template -void test_vectorized_unary(ValueT op1, unsigned expected) { +void test_vectorized_unary(unsigned op1, unsigned expected) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - UnaryOpTestLauncher(grid, threads) + UnaryOpTestLauncher(grid, threads) .template launch_test>(op1, expected); } template -void vectorized_sum_abs_diff_kernel(ValueT *a, ValueT *b, unsigned *r) { - unsigned ua = static_cast(*a); - unsigned ub = static_cast(*b); - - *r = syclcompat::vectorized_sum_abs_diff(ua, ub); +void vectorized_sum_abs_diff_kernel(unsigned *a, unsigned *b, unsigned *r) { + *r = syclcompat::vectorized_sum_abs_diff(*a, *b); } template -void test_vectorized_sum_abs_diff(ValueT op1, ValueT op2, unsigned expected) { +void test_vectorized_sum_abs_diff(unsigned op1, unsigned op2, + unsigned expected) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - BinaryOpTestLauncher(grid, threads) + BinaryOpTestLauncher(grid, threads) .template launch_test>(op1, op2, expected); } +template +void vectorized_ternary_kernel(unsigned *a, unsigned *b, unsigned *c, + unsigned *r, bool need_relu) { + *r = syclcompat::vectorized_ternary(*a, *b, *c, BinaryOp1(), + BinaryOp2(), need_relu); +} + +template +void test_vectorized_ternary(unsigned op1, unsigned op2, unsigned op3, + unsigned expected, bool need_relu = false) { + std::cout << __PRETTY_FUNCTION__ << std::endl; + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + + TernaryOpTestLauncher(grid, threads) + .template launch_test< + vectorized_ternary_kernel>( + op1, op2, op3, expected, need_relu); +} + +template +void vectorized_binary_with_pred_kernel(unsigned *a, unsigned *b, unsigned *r, + bool *pred_hi, bool *pred_lo) { + *r = syclcompat::vectorized_binary_with_pred(*a, *b, BinaryOp(), + pred_hi, pred_lo); +} + +template +void test_vectorized_binary_with_pred(unsigned op1, unsigned op2, + unsigned expected, bool pred_hi, + bool pred_lo) { + std::cout << __PRETTY_FUNCTION__ << std::endl; + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + + BinaryOpTestLauncher(grid, threads) + .template launch_test< + vectorized_binary_with_pred_kernel>( + op1, op2, expected, pred_hi, pred_lo); +} + int main() { - test_vectorized_binary(0x00010002, 0x00040002, - 0x00030000); - test_vectorized_binary(0x00020002, 0xFFFDFFFF, - 0xFFFF0001); - test_vectorized_binary(0x00010008, 0x00020001, - 0x00020005); - test_vectorized_binary(0x00010003, 0x00020005, - 0x00010004); - test_vectorized_binary(0x0FFF0000, 0x00000FFF, - 0x0FFF0FFF); - test_vectorized_binary(0x0FFF0000, 0x00000FFF, - 0x00000000); - test_vectorized_binary(0xFFFB0005, 0x00030008, - 0xFFF8FFFD); - test_vectorized_unary(0xFFFBFFFD, 0x00050003); - test_vectorized_sum_abs_diff(0x00010002, 0x00040002, 0x00000003); + test_vectorized_binary( + 0x00010002, 0x00040002, 0x00030000); + test_vectorized_binary( + 0x00020002, 0xFFFDFFFF, 0xFFFF0001); + test_vectorized_binary( + 0x00010008, 0x00020001, 0x00020005); + test_vectorized_binary(0x00010003, 0x00020005, + 0x00010004); + test_vectorized_binary( + 0x0FFF0000, 0x00000FFF, 0x0FFF0FFF); + test_vectorized_binary( + 0x0FFF0000, 0x00000FFF, 0x00000000); + test_vectorized_binary( + 0xFFFB0005, 0x00030008, 0xFFF8FFFD); + test_vectorized_binary( + 0x00010002, 0x00040002, 0x00030000, true); + test_vectorized_binary( + 0x00020002, 0xFFFDFFFF, 0x00000001, true); + test_vectorized_binary( + 0x00010008, 0x00020001, 0x00020005, true); + test_vectorized_binary(0x00010003, 0x00020005, + 0x00010004, true); + test_vectorized_binary( + 0x0FFF0000, 0x00000FFF, 0x0FFF0FFF, true); + test_vectorized_binary( + 0x0FFF0000, 0x00000FFF, 0x00000000, true); + test_vectorized_binary( + 0xFFFB0005, 0x00030008, 0x00000000, true); + test_vectorized_unary(0xFFFBFFFD, 0x00050003); + test_vectorized_sum_abs_diff(0x00010002, 0x00040002, + 0x00000003); + test_vectorized_ternary, syclcompat::maximum, sycl::ushort2>( + 0x00010002, 0x00040002, 0x00080004, 0x00080004); + test_vectorized_ternary, syclcompat::maximum, sycl::ushort2>( + 0x00010002, 0x00040002, 0x00080004, 0x00080004, true); + test_vectorized_ternary, syclcompat::minimum, sycl::ushort2>( + 0x00010002, 0x00040002, 0x00080004, 0x00050004); + test_vectorized_ternary, syclcompat::minimum, sycl::ushort2>( + 0x00010002, 0x00040002, 0x00080004, 0x00050004, true); + test_vectorized_ternary(0x00010002, 0x00040002, 0x00080004, + 0x00080004); + test_vectorized_ternary(0x00010002, 0x00040002, 0x00080004, + 0x00080004, true); + test_vectorized_ternary(0x00010002, 0x00040002, 0x00080004, + 0x00010002); + test_vectorized_ternary(0x00010002, 0x00040002, 0x00080004, + 0x00010002, true); + test_vectorized_ternary, syclcompat::maximum, sycl::short2>( + 0x80010002, 0x00040002, 0x00080004, 0x00080004); + test_vectorized_ternary, syclcompat::maximum, sycl::short2>( + 0x80010002, 0x00040002, 0x00080004, 0x00080004, true); + test_vectorized_ternary, syclcompat::minimum, sycl::short2>( + 0x80010002, 0x00040002, 0x00080004, 0x80050004); + test_vectorized_ternary, syclcompat::minimum, sycl::short2>( + 0x80010002, 0x00040002, 0x00080004, 0x00000004, true); + test_vectorized_ternary(0x80010002, 0x00040002, 0x00080004, + 0x00080004); + test_vectorized_ternary(0x80010002, 0x00040002, 0x00080004, + 0x00080004, true); + test_vectorized_ternary(0x80010002, 0x00040002, 0x00080004, + 0x80010002); + test_vectorized_ternary(0x80010002, 0x00040002, 0x00080004, + 0x00000002, true); + test_vectorized_binary_with_pred( + 0x80010002, 0x00040002, 0x00040002, false, true); + test_vectorized_binary_with_pred( + 0x80010002, 0x00040002, 0x80010002, true, true); + test_vectorized_binary_with_pred( + 0x80010002, 0x00040002, 0x80010002, true, true); + test_vectorized_binary_with_pred( + 0x80010002, 0x00040002, 0x00040002, false, true); return 0; }