From 45b4436bef2f1f18caa7d8facf2e7ad7b73763ba Mon Sep 17 00:00:00 2001 From: "Tang, Jiajun" Date: Sun, 29 Sep 2024 16:54:45 +0800 Subject: [PATCH 1/4] [SYCL][COMPAT] Add vectorized_ternary and vectorized_with_pred. Add vectorized_ternary and vectorized_with_pred. Update relu and vectorized_binary. Signed-off-by: Tang, Jiajun jiajun.tang@intel.com --- sycl/include/syclcompat/math.hpp | 79 +++++++++++++++++-- sycl/test-e2e/syclcompat/math/math_fixt.hpp | 49 ++++++++++++ .../syclcompat/math/math_vectorized.cpp | 66 ++++++++++++++++ 3 files changed, 189 insertions(+), 5 deletions(-) diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index a3ee2b2085788..56175991e1383 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -863,11 +863,14 @@ relu(const ValueT a) { return ValueT(0); return a; } -template +template inline std::enable_if_t, - sycl::vec> -relu(const sycl::vec a) { - return {relu(a[0]), relu(a[1])}; + sycl::vec> +relu(const sycl::vec a) { + sycl::vec ret; + for (int i = 0; i < NumElements; ++i) + ret[i] = relu(a[i]); + return ret; } template inline std::enable_if_t, @@ -990,6 +993,10 @@ struct maximum { auto operator()(const ValueT x, const ValueT y) const { return sycl::max(x, y); } + template + auto operator()(const T x, const T y, bool *pred) const { + return (x >= y) ? ((*pred = true), x) : ((*pred = false), y); + } }; /// A sycl::min wrapper functors. @@ -998,6 +1005,10 @@ struct minimum { auto operator()(const ValueT x, const ValueT y) const { return sycl::min(x, y); } + template + auto operator()(const T x, const T y, bool *pred) const { + return (x <= y) ? ((*pred = true), x) : ((*pred = false), y); + } }; /// A sycl::sub_sat wrapper functors. @@ -1037,19 +1048,77 @@ struct average { /// \tparam [in] BinaryOperation The binary operation class /// \param [in] a The first value /// \param [in] b The second value +/// \param [in] binary_op The operation to do with the two values +/// \param [in] need_relu Whether the result need relu saturation /// \returns The vectorized binary operation value of the two values template inline unsigned vectorized_binary(unsigned a, unsigned b, - const BinaryOperation binary_op) { + const BinaryOperation binary_op, + bool need_relu = false) { sycl::vec v0{a}, v1{b}; auto v2 = v0.as(); auto v3 = v1.as(); auto v4 = detail::vectorized_binary()(v2, v3, binary_op); + if (need_relu) + v4 = relu(v4); v0 = v4.template as>(); return v0; } +/// Compute two vectorized binary operation value with pred for three values, +/// with each value treated as a 2 \p T type elements vector type. +/// +/// \tparam [in] VecT The type of the vector +/// \tparam [in] BinaryOperation1 The first binary operation class +/// \tparam [in] BinaryOperation2 The second binary operation class +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] c The third value +/// \param [in] binary_op1 The first operation to do with the first two values +/// \param [in] binary_op2 The second operation to do with the third values +/// \param [in] need_relu Whether the result need relu saturation +/// \returns The two vectorized binary operation value of the three values +template +inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, + const BinaryOperation1 binary_op1, + const BinaryOperation2 binary_op2, + bool need_relu = false) { + const auto v1 = sycl::vec(a).as(); + const auto v2 = sycl::vec(b).as(); + const auto v3 = sycl::vec(c).as(); + auto temp = + detail::vectorized_binary()(v1, v2, binary_op1); + temp = + detail::vectorized_binary()(temp, v3, binary_op2); + if (need_relu) + temp = relu(temp); + return temp.template as>(); +} + +/// Compute vectorized binary operation value with pred for two values, with +/// each value treated as a 2 \p T type elements vector type. +/// +/// \tparam [in] T The type of elements type of the vector +/// \tparam [in] BinaryOperation The binary operation class +/// \param [in] a The first value +/// \param [in] b The second value +/// \param [in] binary_op The operation with pred to do with the two values +/// \param [in] pred_hi The pred pointer that pass into high halfword operation +/// \param [in] pred_lo The pred pointer that pass into low halfword operation +/// \returns The vectorized binary operation value of the two values +template +inline unsigned vectorized_with_pred(unsigned a, unsigned b, + const BinaryOperation binary_op, + bool *pred_hi, bool *pred_lo) { + auto v1 = sycl::vec(a).as>(); + auto v2 = sycl::vec(b).as>(); + sycl::vec ret; + ret[0] = binary_op(v1[0], v2[0], pred_lo); + ret[1] = binary_op(v1[1], v2[1], pred_hi); + return ret.template as>(); +} + template using dot_product_acc_t = std::conditional_t && std::is_unsigned_v, diff --git a/sycl/test-e2e/syclcompat/math/math_fixt.hpp b/sycl/test-e2e/syclcompat/math/math_fixt.hpp index cacd6ea1fb32c..539d7e9d7cf38 100644 --- a/sycl/test-e2e/syclcompat/math/math_fixt.hpp +++ b/sycl/test-e2e/syclcompat/math/math_fixt.hpp @@ -208,3 +208,52 @@ class UnaryOpTestLauncher : OpTestLauncher { CHECK(ResultT, res_h_, expected); } }; + +// Templated ResultT to support both arithmetic and boolean operators +template > +class TernaryOpTestLauncher : OpTestLauncher { +protected: + ValueT *op1_; + ValueU *op2_; + ValueV *op2_; + ResultT res_h_, *res_; + +public: + TernaryOpTestLauncher(const syclcompat::dim3 &grid, + const syclcompat::dim3 &threads, + const size_t data_size = 1) + : OpTestLauncher{ + grid, threads, data_size, + should_skip()(syclcompat::get_current_device())} { + if (skip_) + return; + op1_ = syclcompat::malloc(data_size); + op2_ = syclcompat::malloc(data_size); + op3_ = syclcompat::malloc(data_size); + res_ = syclcompat::malloc(data_size); + }; + + virtual ~TernaryOpTestLauncher() { + if (skip_) + return; + syclcompat::free(op1_); + syclcompat::free(op2_); + syclcompat::free(op3_); + syclcompat::free(res_); + } + + template + void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected) { + if (skip_) + return; + syclcompat::memcpy(op1_, &op1, data_size_); + syclcompat::memcpy(op2_, &op2, data_size_); + syclcompat::memcpy(op3_, &op3, data_size_); + syclcompat::launch(grid_, threads_, op1_, op2_, op3_, res_); + syclcompat::wait(); + syclcompat::memcpy(&res_h_, res_, data_size_); + + CHECK(ResultT, res_h_, expected); + }; +}; diff --git a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp index a7870050e2b66..16a4bf92a70fc 100644 --- a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp +++ b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp @@ -84,6 +84,51 @@ void test_vectorized_sum_abs_diff(ValueT op1, ValueT op2, unsigned expected) { expected); } +template +void vectorized_ternary_kernel(ValueT *a, ValueT *b, ValueT *c, unsigned *r, + bool need_relu) { + unsigned ua = static_cast(*a); + unsigned ub = static_cast(*b); + unsigned uc = static_cast(*c); + *r = syclcompat::vectorized_ternary(ua, ub, uc, BinaryOp1(), + BinaryOp2(), need_relu); +} + +template +void test_vectorized_ternary(ValueT op1, ValueT op2, ValueT op3, + unsigned expected, bool need_relu = false) { + std::cout << __PRETTY_FUNCTION__ << std::endl; + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + + TernaryOpTestLauncher(grid, threads) + .template launch_test< + vectorized_binary_kernel>( + op1, op2, op3, expected, need_relu); +} + +template +void vectorized_with_pred_kernel(ValueT *a, ValueT *b, unsigned *r, + bool *pred_hi, bool *pred_lo) { + unsigned ua = static_cast(*a); + unsigned ub = static_cast(*b); + + *r = syclcompat::vectorized_with_pred(ua, ub, BinaryOp(), pred_hi, + pred_lo); +} + +template +void test_vectorized_with_pred(ValueT op1, ValueT op2, unsigned expected, + bool *pred_hi, bool *pred_lo) { + std::cout << __PRETTY_FUNCTION__ << std::endl; + constexpr syclcompat::dim3 grid{1}; + constexpr syclcompat::dim3 threads{1}; + + BinaryOpTestLauncher(grid, threads) + .template launch_test>( + op1, op2, expected, pred_hi, pred_lo); +} + int main() { test_vectorized_binary(0x00010002, 0x00040002, 0x00030000); @@ -101,6 +146,27 @@ int main() { 0xFFF8FFFD); test_vectorized_unary(0xFFFBFFFD, 0x00050003); test_vectorized_sum_abs_diff(0x00010002, 0x00040002, 0x00000003); + test_vectorized_ternary, syclcompat::maximum, uint32_t>( + 0x00010002, 0x00040002, 0x00080004, 0x00030000); + test_vectorized_ternary, syclcompat::maximum, uint32_t>( + 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + test_vectorized_ternary, syclcompat::minimum, uint32_t>( + 0x00010002, 0x00040002, 0x00080004, 0x00030000); + test_vectorized_ternary, syclcompat::minimum, uint32_t>( + 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + test_vectorized_ternary( + 0x00010002, 0x00040002, 0x00080004, 0x00030000); + test_vectorized_ternary( + 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + test_vectorized_ternary( + 0x00010002, 0x00040002, 0x00080004, 0x00030000); + test_vectorized_ternary( + 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + bool pred_hi, bool pred_lo; + test_vectorized_with_pred( + 0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo); + test_vectorized_with_pred( + 0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo); return 0; } From 827e50beaa86820a364f2839f2ea99c1def33661 Mon Sep 17 00:00:00 2001 From: "Tang, Jiajun" Date: Tue, 8 Oct 2024 10:00:30 +0800 Subject: [PATCH 2/4] Fix comment. --- sycl/doc/syclcompat/README.md | 37 ++++++++---- sycl/include/syclcompat/math.hpp | 40 ++++++------- sycl/test-e2e/syclcompat/math/math_fixt.hpp | 45 ++++++++++++++- .../syclcompat/math/math_vectorized.cpp | 56 ++++++++++++------- 4 files changed, 123 insertions(+), 55 deletions(-) diff --git a/sycl/doc/syclcompat/README.md b/sycl/doc/syclcompat/README.md index 492244aef0063..2bd8aca1cd328 100644 --- a/sycl/doc/syclcompat/README.md +++ b/sycl/doc/syclcompat/README.md @@ -1874,17 +1874,11 @@ template inline typename std::enable_if_t, double> pow(const ValueT a, const ValueU b); -template -inline std::enable_if_t || - std::is_same_v, - ValueT> -relu(const ValueT a); +template inline ValueT relu(const ValueT a); -template -inline std::enable_if_t || - std::is_same_v, - sycl::vec> -relu(const sycl::vec a); +template +inline sycl::vec +relu(const sycl::vec a); template inline std::enable_if_t || @@ -1987,9 +1981,12 @@ inline dot_product_acc_t dp4a(T1 a, T2 b, `vectorized_binary` computes the `BinaryOperation` for two operands, with each value treated as a vector type. `vectorized_unary` offers the same -interface for operations with a single operand. +interface for operations with a single operand. `vectorized_ternary` offers the +interface for three operands with two `BinaryOperation`. The implemented `BinaryOperation`s are `abs_diff`, `add_sat`, `rhadd`, `hadd`, `maximum`, `minimum`, and `sub_sat`. +And the `vectorized_with_pred` offers the `BinaryOperation` for two operands, +meanwihle provides the pred of high/low halfword operation. ```cpp namespace syclcompat { @@ -2004,7 +2001,19 @@ struct abs { template inline unsigned vectorized_binary(unsigned a, unsigned b, - const BinaryOperation binary_op); + const BinaryOperation binary_op, + bool need_relu = false); + +template +inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, + const BinaryOperation1 binary_op1, + const BinaryOperation2 binary_op2, + bool need_relu = false); + +template +inline unsigned vectorized_with_pred(unsigned a, unsigned b, + const BinaryOperation binary_op, + bool *pred_hi, bool *pred_lo); // A sycl::abs_diff wrapper functor. struct abs_diff { @@ -2030,11 +2039,15 @@ struct hadd { struct maximum { template auto operator()(const ValueT x, const ValueT y) const; + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const; }; // A sycl::min wrapper functor. struct minimum { template auto operator()(const ValueT x, const ValueT y) const; + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const; }; // A sycl::sub_sat wrapper functor. struct sub_sat { diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index 56175991e1383..9207cbd4fdd04 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -856,18 +856,19 @@ pow(const ValueT a, const ValueU b) { /// Performs relu saturation. /// \param [in] a The input value /// \returns the relu saturation result -template -inline std::enable_if_t, ValueT> -relu(const ValueT a) { - if (!detail::isnan(a) && a < ValueT(0)) +template inline ValueT relu(const ValueT a) { + if constexpr (syclcompat::is_floating_point_v || + std::is_same_v) + if (!detail::isnan(a) && a < ValueT(0)) + return ValueT(0); + if (a < ValueT(0)) return ValueT(0); return a; } template -inline std::enable_if_t, - sycl::vec> +inline sycl::vec relu(const sycl::vec a) { - sycl::vec ret; + sycl::vec ret; for (int i = 0; i < NumElements; ++i) ret[i] = relu(a[i]); return ret; @@ -993,8 +994,8 @@ struct maximum { auto operator()(const ValueT x, const ValueT y) const { return sycl::max(x, y); } - template - auto operator()(const T x, const T y, bool *pred) const { + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const { return (x >= y) ? ((*pred = true), x) : ((*pred = false), y); } }; @@ -1005,8 +1006,8 @@ struct minimum { auto operator()(const ValueT x, const ValueT y) const { return sycl::min(x, y); } - template - auto operator()(const T x, const T y, bool *pred) const { + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const { return (x <= y) ? ((*pred = true), x) : ((*pred = false), y); } }; @@ -1087,13 +1088,12 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, const auto v1 = sycl::vec(a).as(); const auto v2 = sycl::vec(b).as(); const auto v3 = sycl::vec(c).as(); - auto temp = + auto v4 = detail::vectorized_binary()(v1, v2, binary_op1); - temp = - detail::vectorized_binary()(temp, v3, binary_op2); + v4 = detail::vectorized_binary()(v4, v3, binary_op2); if (need_relu) - temp = relu(temp); - return temp.template as>(); + v4 = relu(v4); + return v4.template as>(); } /// Compute vectorized binary operation value with pred for two values, with @@ -1107,13 +1107,13 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, /// \param [in] pred_hi The pred pointer that pass into high halfword operation /// \param [in] pred_lo The pred pointer that pass into low halfword operation /// \returns The vectorized binary operation value of the two values -template +template inline unsigned vectorized_with_pred(unsigned a, unsigned b, const BinaryOperation binary_op, bool *pred_hi, bool *pred_lo) { - auto v1 = sycl::vec(a).as>(); - auto v2 = sycl::vec(b).as>(); - sycl::vec ret; + auto v1 = sycl::vec(a).as>(); + auto v2 = sycl::vec(b).as>(); + sycl::vec ret; ret[0] = binary_op(v1[0], v2[0], pred_lo); ret[1] = binary_op(v1[1], v2[1], pred_hi); return ret.template as>(); diff --git a/sycl/test-e2e/syclcompat/math/math_fixt.hpp b/sycl/test-e2e/syclcompat/math/math_fixt.hpp index 539d7e9d7cf38..7ba88cf1451a3 100644 --- a/sycl/test-e2e/syclcompat/math/math_fixt.hpp +++ b/sycl/test-e2e/syclcompat/math/math_fixt.hpp @@ -134,6 +134,8 @@ class BinaryOpTestLauncher : OpTestLauncher { ValueT *op1_; ValueU *op2_; ResultT res_h_, *res_; + bool *res_hi_; + bool *res_lo_; public: BinaryOpTestLauncher(const syclcompat::dim3 &grid, @@ -147,6 +149,8 @@ class BinaryOpTestLauncher : OpTestLauncher { op1_ = syclcompat::malloc(data_size); op2_ = syclcompat::malloc(data_size); res_ = syclcompat::malloc(data_size); + res_hi_ = syclcompat::malloc(1); + res_lo_ = syclcompat::malloc(1); }; virtual ~BinaryOpTestLauncher() { @@ -155,6 +159,8 @@ class BinaryOpTestLauncher : OpTestLauncher { syclcompat::free(op1_); syclcompat::free(op2_); syclcompat::free(res_); + syclcompat::free(res_hi_); + syclcompat::free(res_lo_); } template @@ -169,6 +175,37 @@ class BinaryOpTestLauncher : OpTestLauncher { CHECK(ResultT, res_h_, expected); }; + template + void launch_test(ValueT op1, ValueU op2, ResultT expected, bool need_relu) { + if (skip_) + return; + syclcompat::memcpy(op1_, &op1, data_size_); + syclcompat::memcpy(op2_, &op2, data_size_); + syclcompat::launch(grid_, threads_, op1_, op2_, res_, need_relu); + syclcompat::wait(); + syclcompat::memcpy(&res_h_, res_, data_size_); + + CHECK(ResultT, res_h_, expected); + }; + template + void launch_test(ValueT op1, ValueU op2, ResultT expected, bool expected_hi, + bool expected_lo) { + if (skip_) + return; + syclcompat::memcpy(op1_, &op1, data_size_); + syclcompat::memcpy(op2_, &op2, data_size_); + syclcompat::launch(grid_, threads_, op1_, op2_, res_, res_hi_, + res_lo_); + syclcompat::wait(); + syclcompat::memcpy(&res_h_, res_, data_size_); + bool res_hi_h_, res_lo_h_; + syclcompat::memcpy(&res_hi_h_, res_hi_, 1); + syclcompat::memcpy(&res_lo_h_, res_lo_, 1); + + CHECK(ResultT, res_h_, expected); + assert(res_hi_h_ == expected_hi); + assert(res_lo_h_ == expected_lo); + }; }; template @@ -216,7 +253,7 @@ class TernaryOpTestLauncher : OpTestLauncher { protected: ValueT *op1_; ValueU *op2_; - ValueV *op2_; + ValueV *op3_; ResultT res_h_, *res_; public: @@ -244,13 +281,15 @@ class TernaryOpTestLauncher : OpTestLauncher { } template - void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected) { + void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected, + bool need_relu = false) { if (skip_) return; syclcompat::memcpy(op1_, &op1, data_size_); syclcompat::memcpy(op2_, &op2, data_size_); syclcompat::memcpy(op3_, &op3, data_size_); - syclcompat::launch(grid_, threads_, op1_, op2_, op3_, res_); + syclcompat::launch(grid_, threads_, op1_, op2_, op3_, res_, + need_relu); syclcompat::wait(); syclcompat::memcpy(&res_h_, res_, data_size_); diff --git a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp index 16a4bf92a70fc..b2257359d1b58 100644 --- a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp +++ b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp @@ -31,21 +31,24 @@ #include "math_fixt.hpp" template -void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r) { +void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r, + bool need_relu) { unsigned ua = static_cast(*a); unsigned ub = static_cast(*b); - *r = syclcompat::vectorized_binary(ua, ub, BinaryOp()); + *r = syclcompat::vectorized_binary(ua, ub, BinaryOp(), + need_relu); } template -void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected) { +void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected, + bool need_relu = false) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; BinaryOpTestLauncher(grid, threads) .template launch_test>( - op1, op2, expected); + op1, op2, expected, need_relu); } template @@ -103,7 +106,7 @@ void test_vectorized_ternary(ValueT op1, ValueT op2, ValueT op3, TernaryOpTestLauncher(grid, threads) .template launch_test< - vectorized_binary_kernel>( + vectorized_ternary_kernel>( op1, op2, op3, expected, need_relu); } @@ -117,16 +120,16 @@ void vectorized_with_pred_kernel(ValueT *a, ValueT *b, unsigned *r, pred_lo); } -template +template void test_vectorized_with_pred(ValueT op1, ValueT op2, unsigned expected, - bool *pred_hi, bool *pred_lo) { + bool expected_hi, bool expected_lo) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; BinaryOpTestLauncher(grid, threads) - .template launch_test>( - op1, op2, expected, pred_hi, pred_lo); + .template launch_test>( + op1, op2, expected, expected_hi, expected_lo); } int main() { @@ -144,29 +147,42 @@ int main() { 0x00000000); test_vectorized_binary(0xFFFB0005, 0x00030008, 0xFFF8FFFD); + test_vectorized_binary(0x00010002, 0x00040002, + 0x00030000, true); + test_vectorized_binary(0x00020002, 0xFFFDFFFF, + 0x00000001, true); + test_vectorized_binary(0x00010008, 0x00020001, + 0x00020005, true); + test_vectorized_binary(0x00010003, 0x00020005, + 0x00010004, true); + test_vectorized_binary(0x0FFF0000, 0x00000FFF, + 0x0FFF0FFF, true); + test_vectorized_binary(0x0FFF0000, 0x00000FFF, + 0x00000000, true); + test_vectorized_binary(0xFFFB0005, 0x00030008, + 0x00000000, true); test_vectorized_unary(0xFFFBFFFD, 0x00050003); test_vectorized_sum_abs_diff(0x00010002, 0x00040002, 0x00000003); test_vectorized_ternary, syclcompat::maximum, uint32_t>( - 0x00010002, 0x00040002, 0x00080004, 0x00030000); + 0x00010002, 0x00040002, 0x00080004, 0x00080004); test_vectorized_ternary, syclcompat::maximum, uint32_t>( - 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + 0x00010002, 0x00040002, 0x00080004, 0x00080004, true); test_vectorized_ternary, syclcompat::minimum, uint32_t>( - 0x00010002, 0x00040002, 0x00080004, 0x00030000); + 0x00010002, 0x00040002, 0x00080004, 0x00050004); test_vectorized_ternary, syclcompat::minimum, uint32_t>( - 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + 0x00010002, 0x00040002, 0x00080004, 0x00050004, true); test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00030000); + 0x00010002, 0x00040002, 0x00080004, 0x00080004); test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + 0x00010002, 0x00040002, 0x00080004, 0x00080004, true); test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00030000); + 0x00010002, 0x00040002, 0x00080004, 0x00010002); test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); - bool pred_hi, bool pred_lo; + 0x00010002, 0x00040002, 0x00080004, 0x00010002, true); test_vectorized_with_pred( - 0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo); + 0x00010002, 0x00040002, 0x00040002, false, true); test_vectorized_with_pred( - 0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo); + 0x00010002, 0x00040002, 0x00010002, true, true); return 0; } From 829c529640ccd1d91bc11216712c63819caa07c0 Mon Sep 17 00:00:00 2001 From: "Tang, Jiajun" Date: Wed, 16 Oct 2024 09:17:04 +0800 Subject: [PATCH 3/4] fix comment. --- sycl/include/syclcompat/math.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index 9207cbd4fdd04..8baa2b0f42679 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -859,8 +859,8 @@ pow(const ValueT a, const ValueU b) { template inline ValueT relu(const ValueT a) { if constexpr (syclcompat::is_floating_point_v || std::is_same_v) - if (!detail::isnan(a) && a < ValueT(0)) - return ValueT(0); + if (detail::isnan(a)) + return a; if (a < ValueT(0)) return ValueT(0); return a; @@ -1104,13 +1104,15 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, /// \param [in] a The first value /// \param [in] b The second value /// \param [in] binary_op The operation with pred to do with the two values -/// \param [in] pred_hi The pred pointer that pass into high halfword operation -/// \param [in] pred_lo The pred pointer that pass into low halfword operation +/// \param [out] pred_hi The pred pointer that pass into high halfword operation +/// \param [out] pred_lo The pred pointer that pass into low halfword operation /// \returns The vectorized binary operation value of the two values template inline unsigned vectorized_with_pred(unsigned a, unsigned b, const BinaryOperation binary_op, bool *pred_hi, bool *pred_lo) { + static_assert(std::is_same_v || + std::is_same_v); auto v1 = sycl::vec(a).as>(); auto v2 = sycl::vec(b).as>(); sycl::vec ret; From d43849ff3e6ae1fca4c2506318cdb451adaff5c5 Mon Sep 17 00:00:00 2001 From: "Tang, Jiajun" Date: Fri, 18 Oct 2024 16:02:03 +0800 Subject: [PATCH 4/4] Fix comment. --- sycl/include/syclcompat/math.hpp | 25 +-- sycl/test-e2e/syclcompat/math/math_fixt.hpp | 12 +- sycl/test-e2e/syclcompat/math/math_ops.cpp | 8 +- .../syclcompat/math/math_vectorized.cpp | 191 ++++++++++-------- 4 files changed, 126 insertions(+), 110 deletions(-) diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index 8baa2b0f42679..b0b8a93d6697c 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -857,8 +857,7 @@ pow(const ValueT a, const ValueU b) { /// \param [in] a The input value /// \returns the relu saturation result template inline ValueT relu(const ValueT a) { - if constexpr (syclcompat::is_floating_point_v || - std::is_same_v) + if constexpr (syclcompat::is_floating_point_v) if (detail::isnan(a)) return a; if (a < ValueT(0)) @@ -874,9 +873,7 @@ relu(const sycl::vec a) { return ret; } template -inline std::enable_if_t, - sycl::marray> -relu(const sycl::marray a) { +inline sycl::marray relu(const sycl::marray a) { return {relu(a[0]), relu(a[1])}; } @@ -1099,7 +1096,7 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, /// Compute vectorized binary operation value with pred for two values, with /// each value treated as a 2 \p T type elements vector type. /// -/// \tparam [in] T The type of elements type of the vector +/// \tparam [in] VecT The type of the vector /// \tparam [in] BinaryOperation The binary operation class /// \param [in] a The first value /// \param [in] b The second value @@ -1107,15 +1104,13 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, /// \param [out] pred_hi The pred pointer that pass into high halfword operation /// \param [out] pred_lo The pred pointer that pass into low halfword operation /// \returns The vectorized binary operation value of the two values -template -inline unsigned vectorized_with_pred(unsigned a, unsigned b, - const BinaryOperation binary_op, - bool *pred_hi, bool *pred_lo) { - static_assert(std::is_same_v || - std::is_same_v); - auto v1 = sycl::vec(a).as>(); - auto v2 = sycl::vec(b).as>(); - sycl::vec ret; +template +inline unsigned vectorized_binary_with_pred(unsigned a, unsigned b, + const BinaryOperation binary_op, + bool *pred_hi, bool *pred_lo) { + auto v1 = sycl::vec(a).as(); + auto v2 = sycl::vec(b).as(); + VecT ret; ret[0] = binary_op(v1[0], v2[0], pred_lo); ret[1] = binary_op(v1[1], v2[1], pred_hi); return ret.template as>(); diff --git a/sycl/test-e2e/syclcompat/math/math_fixt.hpp b/sycl/test-e2e/syclcompat/math/math_fixt.hpp index 7ba88cf1451a3..4647142da6c61 100644 --- a/sycl/test-e2e/syclcompat/math/math_fixt.hpp +++ b/sycl/test-e2e/syclcompat/math/math_fixt.hpp @@ -260,14 +260,14 @@ class TernaryOpTestLauncher : OpTestLauncher { TernaryOpTestLauncher(const syclcompat::dim3 &grid, const syclcompat::dim3 &threads, const size_t data_size = 1) - : OpTestLauncher{ - grid, threads, data_size, - should_skip()(syclcompat::get_current_device())} { + : OpTestLauncher{grid, threads, data_size, + should_skip()( + syclcompat::get_current_device())} { if (skip_) return; op1_ = syclcompat::malloc(data_size); op2_ = syclcompat::malloc(data_size); - op3_ = syclcompat::malloc(data_size); + op3_ = syclcompat::malloc(data_size); res_ = syclcompat::malloc(data_size); }; @@ -281,13 +281,13 @@ class TernaryOpTestLauncher : OpTestLauncher { } template - void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected, + void launch_test(ValueT op1, ValueU op2, ValueV op3, ResultT expected, bool need_relu = false) { if (skip_) return; syclcompat::memcpy(op1_, &op1, data_size_); syclcompat::memcpy(op2_, &op2, data_size_); - syclcompat::memcpy(op3_, &op3, data_size_); + syclcompat::memcpy(op3_, &op3, data_size_); syclcompat::launch(grid_, threads_, op1_, op2_, op3_, res_, need_relu); syclcompat::wait(); diff --git a/sycl/test-e2e/syclcompat/math/math_ops.cpp b/sycl/test-e2e/syclcompat/math/math_ops.cpp index d52d9c60d8ded..baaa0f7f15210 100644 --- a/sycl/test-e2e/syclcompat/math/math_ops.cpp +++ b/sycl/test-e2e/syclcompat/math/math_ops.cpp @@ -226,8 +226,10 @@ template void test_syclcompat_relu() { UnaryOpTestLauncher(grid, threads) .template launch_test>(op1, res1); - const ValueT op2 = static_cast(-3); - const ValueT res2 = static_cast(0); + const ValueT op2 = std::is_signed_v ? static_cast(-3) + : static_cast(2); + const ValueT res2 = std::is_signed_v ? static_cast(0) + : static_cast(2); UnaryOpTestLauncher(grid, threads) .template launch_test>(op2, res2); @@ -374,7 +376,7 @@ int main() { test_syclcompat_pow(); test_syclcompat_pow(); - INSTANTIATE_ALL_TYPES(fp_type_list, test_syclcompat_relu); + INSTANTIATE_ALL_TYPES(value_type_list, test_syclcompat_relu); INSTANTIATE_ALL_TYPES(fp_type_list_no_bfloat16, test_syclcompat_cbrt); INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_isnan); diff --git a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp index b2257359d1b58..9c57c88ce445b 100644 --- a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp +++ b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp @@ -31,158 +31,177 @@ #include "math_fixt.hpp" template -void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r, +void vectorized_binary_kernel(unsigned *a, unsigned *b, unsigned *r, bool need_relu) { - unsigned ua = static_cast(*a); - unsigned ub = static_cast(*b); - *r = syclcompat::vectorized_binary(ua, ub, BinaryOp(), - need_relu); + *r = syclcompat::vectorized_binary(*a, *b, BinaryOp(), need_relu); } template -void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected, +void test_vectorized_binary(unsigned op1, unsigned op2, unsigned expected, bool need_relu = false) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - BinaryOpTestLauncher(grid, threads) + BinaryOpTestLauncher(grid, threads) .template launch_test>( op1, op2, expected, need_relu); } template -void vectorized_unary_kernel(ValueT *a, unsigned *r) { - unsigned ua = static_cast(*a); - *r = syclcompat::vectorized_unary(ua, UnaryOp()); +void vectorized_unary_kernel(unsigned *a, unsigned *r) { + *r = syclcompat::vectorized_unary(*a, UnaryOp()); } template -void test_vectorized_unary(ValueT op1, unsigned expected) { +void test_vectorized_unary(unsigned op1, unsigned expected) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - UnaryOpTestLauncher(grid, threads) + UnaryOpTestLauncher(grid, threads) .template launch_test>(op1, expected); } template -void vectorized_sum_abs_diff_kernel(ValueT *a, ValueT *b, unsigned *r) { - unsigned ua = static_cast(*a); - unsigned ub = static_cast(*b); - - *r = syclcompat::vectorized_sum_abs_diff(ua, ub); +void vectorized_sum_abs_diff_kernel(unsigned *a, unsigned *b, unsigned *r) { + *r = syclcompat::vectorized_sum_abs_diff(*a, *b); } template -void test_vectorized_sum_abs_diff(ValueT op1, ValueT op2, unsigned expected) { +void test_vectorized_sum_abs_diff(unsigned op1, unsigned op2, + unsigned expected) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - BinaryOpTestLauncher(grid, threads) + BinaryOpTestLauncher(grid, threads) .template launch_test>(op1, op2, expected); } template -void vectorized_ternary_kernel(ValueT *a, ValueT *b, ValueT *c, unsigned *r, - bool need_relu) { - unsigned ua = static_cast(*a); - unsigned ub = static_cast(*b); - unsigned uc = static_cast(*c); - *r = syclcompat::vectorized_ternary(ua, ub, uc, BinaryOp1(), - BinaryOp2(), need_relu); +void vectorized_ternary_kernel(unsigned *a, unsigned *b, unsigned *c, + unsigned *r, bool need_relu) { + *r = syclcompat::vectorized_ternary(*a, *b, *c, BinaryOp1(), + BinaryOp2(), need_relu); } template -void test_vectorized_ternary(ValueT op1, ValueT op2, ValueT op3, +void test_vectorized_ternary(unsigned op1, unsigned op2, unsigned op3, unsigned expected, bool need_relu = false) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - TernaryOpTestLauncher(grid, threads) + TernaryOpTestLauncher(grid, threads) .template launch_test< vectorized_ternary_kernel>( op1, op2, op3, expected, need_relu); } template -void vectorized_with_pred_kernel(ValueT *a, ValueT *b, unsigned *r, - bool *pred_hi, bool *pred_lo) { - unsigned ua = static_cast(*a); - unsigned ub = static_cast(*b); - - *r = syclcompat::vectorized_with_pred(ua, ub, BinaryOp(), pred_hi, - pred_lo); +void vectorized_binary_with_pred_kernel(unsigned *a, unsigned *b, unsigned *r, + bool *pred_hi, bool *pred_lo) { + *r = syclcompat::vectorized_binary_with_pred(*a, *b, BinaryOp(), + pred_hi, pred_lo); } template -void test_vectorized_with_pred(ValueT op1, ValueT op2, unsigned expected, - bool expected_hi, bool expected_lo) { +void test_vectorized_binary_with_pred(unsigned op1, unsigned op2, + unsigned expected, bool pred_hi, + bool pred_lo) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; - BinaryOpTestLauncher(grid, threads) - .template launch_test>( - op1, op2, expected, expected_hi, expected_lo); + BinaryOpTestLauncher(grid, threads) + .template launch_test< + vectorized_binary_with_pred_kernel>( + op1, op2, expected, pred_hi, pred_lo); } int main() { - test_vectorized_binary(0x00010002, 0x00040002, - 0x00030000); - test_vectorized_binary(0x00020002, 0xFFFDFFFF, - 0xFFFF0001); - test_vectorized_binary(0x00010008, 0x00020001, - 0x00020005); - test_vectorized_binary(0x00010003, 0x00020005, - 0x00010004); - test_vectorized_binary(0x0FFF0000, 0x00000FFF, - 0x0FFF0FFF); - test_vectorized_binary(0x0FFF0000, 0x00000FFF, - 0x00000000); - test_vectorized_binary(0xFFFB0005, 0x00030008, - 0xFFF8FFFD); - test_vectorized_binary(0x00010002, 0x00040002, - 0x00030000, true); - test_vectorized_binary(0x00020002, 0xFFFDFFFF, - 0x00000001, true); - test_vectorized_binary(0x00010008, 0x00020001, - 0x00020005, true); - test_vectorized_binary(0x00010003, 0x00020005, - 0x00010004, true); - test_vectorized_binary(0x0FFF0000, 0x00000FFF, - 0x0FFF0FFF, true); - test_vectorized_binary(0x0FFF0000, 0x00000FFF, - 0x00000000, true); - test_vectorized_binary(0xFFFB0005, 0x00030008, - 0x00000000, true); - test_vectorized_unary(0xFFFBFFFD, 0x00050003); - test_vectorized_sum_abs_diff(0x00010002, 0x00040002, 0x00000003); - test_vectorized_ternary, syclcompat::maximum, uint32_t>( + test_vectorized_binary( + 0x00010002, 0x00040002, 0x00030000); + test_vectorized_binary( + 0x00020002, 0xFFFDFFFF, 0xFFFF0001); + test_vectorized_binary( + 0x00010008, 0x00020001, 0x00020005); + test_vectorized_binary(0x00010003, 0x00020005, + 0x00010004); + test_vectorized_binary( + 0x0FFF0000, 0x00000FFF, 0x0FFF0FFF); + test_vectorized_binary( + 0x0FFF0000, 0x00000FFF, 0x00000000); + test_vectorized_binary( + 0xFFFB0005, 0x00030008, 0xFFF8FFFD); + test_vectorized_binary( + 0x00010002, 0x00040002, 0x00030000, true); + test_vectorized_binary( + 0x00020002, 0xFFFDFFFF, 0x00000001, true); + test_vectorized_binary( + 0x00010008, 0x00020001, 0x00020005, true); + test_vectorized_binary(0x00010003, 0x00020005, + 0x00010004, true); + test_vectorized_binary( + 0x0FFF0000, 0x00000FFF, 0x0FFF0FFF, true); + test_vectorized_binary( + 0x0FFF0000, 0x00000FFF, 0x00000000, true); + test_vectorized_binary( + 0xFFFB0005, 0x00030008, 0x00000000, true); + test_vectorized_unary(0xFFFBFFFD, 0x00050003); + test_vectorized_sum_abs_diff(0x00010002, 0x00040002, + 0x00000003); + test_vectorized_ternary, syclcompat::maximum, sycl::ushort2>( 0x00010002, 0x00040002, 0x00080004, 0x00080004); - test_vectorized_ternary, syclcompat::maximum, uint32_t>( + test_vectorized_ternary, syclcompat::maximum, sycl::ushort2>( 0x00010002, 0x00040002, 0x00080004, 0x00080004, true); - test_vectorized_ternary, syclcompat::minimum, uint32_t>( + test_vectorized_ternary, syclcompat::minimum, sycl::ushort2>( 0x00010002, 0x00040002, 0x00080004, 0x00050004); - test_vectorized_ternary, syclcompat::minimum, uint32_t>( + test_vectorized_ternary, syclcompat::minimum, sycl::ushort2>( 0x00010002, 0x00040002, 0x00080004, 0x00050004, true); - test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00080004); - test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00080004, true); - test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00010002); - test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00010002, true); - test_vectorized_with_pred( - 0x00010002, 0x00040002, 0x00040002, false, true); - test_vectorized_with_pred( - 0x00010002, 0x00040002, 0x00010002, true, true); + test_vectorized_ternary(0x00010002, 0x00040002, 0x00080004, + 0x00080004); + test_vectorized_ternary(0x00010002, 0x00040002, 0x00080004, + 0x00080004, true); + test_vectorized_ternary(0x00010002, 0x00040002, 0x00080004, + 0x00010002); + test_vectorized_ternary(0x00010002, 0x00040002, 0x00080004, + 0x00010002, true); + test_vectorized_ternary, syclcompat::maximum, sycl::short2>( + 0x80010002, 0x00040002, 0x00080004, 0x00080004); + test_vectorized_ternary, syclcompat::maximum, sycl::short2>( + 0x80010002, 0x00040002, 0x00080004, 0x00080004, true); + test_vectorized_ternary, syclcompat::minimum, sycl::short2>( + 0x80010002, 0x00040002, 0x00080004, 0x80050004); + test_vectorized_ternary, syclcompat::minimum, sycl::short2>( + 0x80010002, 0x00040002, 0x00080004, 0x00000004, true); + test_vectorized_ternary(0x80010002, 0x00040002, 0x00080004, + 0x00080004); + test_vectorized_ternary(0x80010002, 0x00040002, 0x00080004, + 0x00080004, true); + test_vectorized_ternary(0x80010002, 0x00040002, 0x00080004, + 0x80010002); + test_vectorized_ternary(0x80010002, 0x00040002, 0x00080004, + 0x00000002, true); + test_vectorized_binary_with_pred( + 0x80010002, 0x00040002, 0x00040002, false, true); + test_vectorized_binary_with_pred( + 0x80010002, 0x00040002, 0x80010002, true, true); + test_vectorized_binary_with_pred( + 0x80010002, 0x00040002, 0x80010002, true, true); + test_vectorized_binary_with_pred( + 0x80010002, 0x00040002, 0x00040002, false, true); return 0; }