@@ -119,34 +119,13 @@ class vectorized_binary {
119119 }
120120};
121121
122- // Vectorized_binary for logical operations
123122template <typename VecT, class BinaryOperation >
124123class vectorized_binary <
125124 VecT, BinaryOperation,
126- std::enable_if_t <std::is_same_v<
127- bool , decltype (std::declval<BinaryOperation>()(
128- std::declval<typename VecT::element_type>(),
129- std::declval<typename VecT::element_type>()))>>> {
125+ std::void_t <std::invoke_result_t <BinaryOperation, VecT, VecT>>> {
130126public:
131127 inline VecT operator ()(VecT a, VecT b, const BinaryOperation binary_op) {
132- unsigned result = 0 ;
133- constexpr size_t elem_size = 8 * sizeof (typename VecT::element_type);
134- static_assert (elem_size < 32 ,
135- " Vector element size must be less than 4 bytes" );
136- constexpr unsigned bool_mask = (1U << elem_size) - 1 ;
137-
138- for (size_t i = 0 ; i < a.size (); ++i) {
139- bool comp_result = binary_op (a[i], b[i]);
140- result |= (comp_result ? bool_mask : 0U ) << (i * elem_size);
141- }
142-
143- VecT v4;
144- for (size_t i = 0 ; i < v4.size (); ++i) {
145- v4[i] = static_cast <typename VecT::element_type>(
146- (result >> (i * elem_size)) & bool_mask);
147- }
148-
149- return v4;
128+ return binary_op (a, b).template as <VecT>();
150129 }
151130};
152131
@@ -694,8 +673,9 @@ inline unsigned vectorized_unary(unsigned a, const UnaryOperation unary_op) {
694673template <typename VecT>
695674inline unsigned vectorized_sum_abs_diff (unsigned a, unsigned b) {
696675 sycl::vec<unsigned , 1 > v0{a}, v1{b};
697- auto v2 = v0.as <VecT>();
698- auto v3 = v1.as <VecT>();
676+ // Need convert element type to wider signed type to avoid overflow.
677+ auto v2 = v0.as <VecT>().template convert <int >();
678+ auto v3 = v1.as <VecT>().template convert <int >();
699679 auto v4 = sycl::abs_diff (v2, v3);
700680 unsigned sum = 0 ;
701681 for (size_t i = 0 ; i < v4.size (); ++i) {
@@ -1095,13 +1075,8 @@ inline unsigned vectorized_binary(unsigned a, unsigned b,
10951075 auto v3 = v1.as <VecT>();
10961076 auto v4 =
10971077 detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1098- if constexpr (!std::is_same_v<
1099- bool , decltype (std::declval<BinaryOperation>()(
1100- std::declval<typename VecT::element_type>(),
1101- std::declval<typename VecT::element_type>()))>) {
1102- if (need_relu)
1103- v4 = relu (v4);
1104- }
1078+ if (need_relu)
1079+ v4 = relu (v4);
11051080 v0 = v4.template as <sycl::vec<unsigned , 1 >>();
11061081 return v0;
11071082}
0 commit comments