@@ -856,23 +856,24 @@ pow(const ValueT a, const ValueU b) {
856856// / Performs relu saturation.
857857// / \param [in] a The input value
858858// / \returns the relu saturation result
859- template <typename ValueT>
860- inline std::enable_if_t <syclcompat::is_floating_point_v<ValueT>, ValueT>
861- relu (const ValueT a) {
862- if (!detail::isnan (a) && a < ValueT (0 ))
859+ template <typename ValueT> inline ValueT relu (const ValueT a) {
860+ if constexpr (syclcompat::is_floating_point_v<ValueT>)
861+ if (detail::isnan (a))
862+ return a;
863+ if (a < ValueT (0 ))
863864 return ValueT (0 );
864865 return a;
865866}
866- template <class ValueT >
867- inline std::enable_if_t <syclcompat::is_floating_point_v<ValueT>,
868- sycl::vec<ValueT, 2 >>
869- relu (const sycl::vec<ValueT, 2 > a) {
870- return {relu (a[0 ]), relu (a[1 ])};
867+ template <class ValueT , int NumElements>
868+ inline sycl::vec<ValueT, NumElements>
869+ relu (const sycl::vec<ValueT, NumElements> a) {
870+ sycl::vec<ValueT, NumElements> ret;
871+ for (int i = 0 ; i < NumElements; ++i)
872+ ret[i] = relu (a[i]);
873+ return ret;
871874}
872875template <class ValueT >
873- inline std::enable_if_t <syclcompat::is_floating_point_v<ValueT>,
874- sycl::marray<ValueT, 2 >>
875- relu (const sycl::marray<ValueT, 2 > a) {
876+ inline sycl::marray<ValueT, 2 > relu (const sycl::marray<ValueT, 2 > a) {
876877 return {relu (a[0 ]), relu (a[1 ])};
877878}
878879
@@ -990,6 +991,10 @@ struct maximum {
990991 auto operator ()(const ValueT x, const ValueT y) const {
991992 return sycl::max (x, y);
992993 }
994+ template <typename ValueT>
995+ auto operator ()(const ValueT x, const ValueT y, bool *pred) const {
996+ return (x >= y) ? ((*pred = true ), x) : ((*pred = false ), y);
997+ }
993998};
994999
9951000// / A sycl::min wrapper functors.
@@ -998,6 +1003,10 @@ struct minimum {
9981003 auto operator ()(const ValueT x, const ValueT y) const {
9991004 return sycl::min (x, y);
10001005 }
1006+ template <typename ValueT>
1007+ auto operator ()(const ValueT x, const ValueT y, bool *pred) const {
1008+ return (x <= y) ? ((*pred = true ), x) : ((*pred = false ), y);
1009+ }
10011010};
10021011
10031012// / A sycl::sub_sat wrapper functors.
@@ -1037,19 +1046,76 @@ struct average {
10371046// / \tparam [in] BinaryOperation The binary operation class
10381047// / \param [in] a The first value
10391048// / \param [in] b The second value
1049+ // / \param [in] binary_op The operation to do with the two values
1050+ // / \param [in] need_relu Whether the result needs relu saturation
10401051// / \returns The vectorized binary operation value of the two values
10411052template <typename VecT, class BinaryOperation >
10421053inline unsigned vectorized_binary (unsigned a, unsigned b,
1043- const BinaryOperation binary_op) {
1054+ const BinaryOperation binary_op,
1055+ bool need_relu = false ) {
10441056 sycl::vec<unsigned , 1 > v0{a}, v1{b};
10451057 auto v2 = v0.as <VecT>();
10461058 auto v3 = v1.as <VecT>();
10471059 auto v4 =
10481060 detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1061+ if (need_relu)
1062+ v4 = relu (v4);
10491063 v0 = v4.template as <sycl::vec<unsigned , 1 >>();
10501064 return v0;
10511065}
10521066
1067+ // / Compute two vectorized binary operation value with pred for three values,
1068+ // / with each value treated as a 2 \p T type elements vector type.
1069+ // /
1070+ // / \tparam [in] VecT The type of the vector
1071+ // / \tparam [in] BinaryOperation1 The first binary operation class
1072+ // / \tparam [in] BinaryOperation2 The second binary operation class
1073+ // / \param [in] a The first value
1074+ // / \param [in] b The second value
1075+ // / \param [in] c The third value
1076+ // / \param [in] binary_op1 The first operation to do with the first two values
1077+ // / \param [in] binary_op2 The second operation to do with the third values
1078+ // / \param [in] need_relu Whether the result need relu saturation
1079+ // / \returns The two vectorized binary operation value of the three values
1080+ template <typename VecT, typename BinaryOperation1, typename BinaryOperation2>
1081+ inline unsigned vectorized_ternary (unsigned a, unsigned b, unsigned c,
1082+ const BinaryOperation1 binary_op1,
1083+ const BinaryOperation2 binary_op2,
1084+ bool need_relu = false ) {
1085+ const auto v1 = sycl::vec<unsigned , 1 >(a).as <VecT>();
1086+ const auto v2 = sycl::vec<unsigned , 1 >(b).as <VecT>();
1087+ const auto v3 = sycl::vec<unsigned , 1 >(c).as <VecT>();
1088+ auto v4 =
1089+ detail::vectorized_binary<VecT, BinaryOperation1>()(v1, v2, binary_op1);
1090+ v4 = detail::vectorized_binary<VecT, BinaryOperation2>()(v4, v3, binary_op2);
1091+ if (need_relu)
1092+ v4 = relu (v4);
1093+ return v4.template as <sycl::vec<unsigned , 1 >>();
1094+ }
1095+
1096+ // / Compute vectorized binary operation value with pred for two values, with
1097+ // / each value treated as a 2 \p T type elements vector type.
1098+ // /
1099+ // / \tparam [in] VecT The type of the vector
1100+ // / \tparam [in] BinaryOperation The binary operation class
1101+ // / \param [in] a The first value
1102+ // / \param [in] b The second value
1103+ // / \param [in] binary_op The operation with pred to do with the two values
1104+ // / \param [out] pred_hi The pred pointer that pass into high halfword operation
1105+ // / \param [out] pred_lo The pred pointer that pass into low halfword operation
1106+ // / \returns The vectorized binary operation value of the two values
1107+ template <typename VecT, typename BinaryOperation>
1108+ inline unsigned vectorized_binary_with_pred (unsigned a, unsigned b,
1109+ const BinaryOperation binary_op,
1110+ bool *pred_hi, bool *pred_lo) {
1111+ auto v1 = sycl::vec<unsigned , 1 >(a).as <VecT>();
1112+ auto v2 = sycl::vec<unsigned , 1 >(b).as <VecT>();
1113+ VecT ret;
1114+ ret[0 ] = binary_op (v1[0 ], v2[0 ], pred_lo);
1115+ ret[1 ] = binary_op (v1[1 ], v2[1 ], pred_hi);
1116+ return ret.template as <sycl::vec<unsigned , 1 >>();
1117+ }
1118+
10531119template <typename T1, typename T2>
10541120using dot_product_acc_t =
10551121 std::conditional_t <std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
0 commit comments