Skip to content

Commit c13b071

Browse files
[SYCL][COMPAT] Extended vectorized_binary support to logical operators (#15759)
Adds support for logical (comparison) operators to `vectorized_binary`, together with the relevant unit tests.
1 parent 730cd3a commit c13b071

File tree

2 files changed

+114
-4
lines changed

2 files changed

+114
-4
lines changed

sycl/include/syclcompat/math.hpp

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,37 @@ class vectorized_binary {
118118
}
119119
};
120120

121+
// Vectorized_binary for logical operations
122+
template <typename VecT, class BinaryOperation>
123+
class vectorized_binary<
124+
VecT, BinaryOperation,
125+
std::enable_if_t<std::is_same_v<
126+
bool, decltype(std::declval<BinaryOperation>()(
127+
std::declval<typename VecT::element_type>(),
128+
std::declval<typename VecT::element_type>()))>>> {
129+
public:
130+
inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) {
131+
unsigned result = 0;
132+
constexpr size_t elem_size = 8 * sizeof(typename VecT::element_type);
133+
static_assert(elem_size < 32,
134+
"Vector element size must be less than 4 bytes");
135+
constexpr unsigned bool_mask = (1U << elem_size) - 1;
136+
137+
for (size_t i = 0; i < a.size(); ++i) {
138+
bool comp_result = binary_op(a[i], b[i]);
139+
result |= (comp_result ? bool_mask : 0U) << (i * elem_size);
140+
}
141+
142+
VecT v4;
143+
for (size_t i = 0; i < v4.size(); ++i) {
144+
v4[i] = static_cast<typename VecT::element_type>(
145+
(result >> (i * elem_size)) & bool_mask);
146+
}
147+
148+
return v4;
149+
}
150+
};
151+
121152
/// Extend the 'val' to 'bit' size, zero extend for unsigned int and signed
122153
/// extend for signed int. Returns a signed integer type.
123154
template <typename ValueT>
@@ -1040,7 +1071,7 @@ struct average {
10401071

10411072
} // namespace detail
10421073

1043-
/// Compute vectorized binary operation value for two values, with each value
1074+
/// Compute the vectorized binary operation of two values, with each value
10441075
/// treated as a vector type \p VecT.
10451076
/// \tparam [in] VecT The type of the vector
10461077
/// \tparam [in] BinaryOperation The binary operation class
@@ -1052,14 +1083,19 @@ struct average {
10521083
template <typename VecT, class BinaryOperation>
10531084
inline unsigned vectorized_binary(unsigned a, unsigned b,
10541085
const BinaryOperation binary_op,
1055-
bool need_relu = false) {
1086+
[[maybe_unused]] bool need_relu = false) {
10561087
sycl::vec<unsigned, 1> v0{a}, v1{b};
10571088
auto v2 = v0.as<VecT>();
10581089
auto v3 = v1.as<VecT>();
10591090
auto v4 =
10601091
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1061-
if (need_relu)
1062-
v4 = relu(v4);
1092+
if constexpr (!std::is_same_v<
1093+
bool, decltype(std::declval<BinaryOperation>()(
1094+
std::declval<typename VecT::element_type>(),
1095+
std::declval<typename VecT::element_type>()))>) {
1096+
if (need_relu)
1097+
v4 = relu(v4);
1098+
}
10631099
v0 = v4.template as<sycl::vec<unsigned, 1>>();
10641100
return v0;
10651101
}

sycl/test-e2e/syclcompat/math/math_vectorized.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ void test_vectorized_binary(unsigned op1, unsigned op2, unsigned expected,
4848
op1, op2, expected, need_relu);
4949
}
5050

51+
template <typename BinaryOp, typename ValueT>
52+
void test_vectorized_binary_logical(unsigned op1, unsigned op2,
53+
unsigned expected) {
54+
std::cout << __PRETTY_FUNCTION__ << std::endl;
55+
constexpr syclcompat::dim3 grid{1};
56+
constexpr syclcompat::dim3 threads{1};
57+
58+
BinaryOpTestLauncher<unsigned, unsigned, unsigned>(grid, threads)
59+
.template launch_test<vectorized_binary_kernel<BinaryOp, ValueT>>(
60+
op1, op2, expected, false);
61+
}
62+
5163
template <typename UnaryOp, typename ValueT>
5264
void vectorized_unary_kernel(unsigned *a, unsigned *r) {
5365
*r = syclcompat::vectorized_unary<ValueT>(*a, UnaryOp());
@@ -203,5 +215,67 @@ int main() {
203215
test_vectorized_binary_with_pred<syclcompat::minimum, sycl::ushort2>(
204216
0x80010002, 0x00040002, 0x00040002, false, true);
205217

218+
// Logical Binary Operators v2
219+
test_vectorized_binary_logical<std::equal_to<>, sycl::short2>(
220+
0xFFF00002, 0xFFF00001, 0xFFFF0000);
221+
test_vectorized_binary_logical<std::equal_to<>, sycl::short2>(
222+
0x0001F00F, 0x0003F00F, 0x0000FFFF);
223+
224+
test_vectorized_binary_logical<std::not_equal_to<>, sycl::short2>(
225+
0xFFF00002, 0xFFF00001, 0x0000FFFF);
226+
test_vectorized_binary_logical<std::not_equal_to<>, sycl::short2>(
227+
0x0001F00F, 0x0003F00F, 0xFFFF0000);
228+
229+
test_vectorized_binary_logical<std::greater_equal<>, sycl::short2>(
230+
0xFFF00002, 0xFFF00001, 0xFFFFFFFF);
231+
test_vectorized_binary_logical<std::greater_equal<>, sycl::short2>(
232+
0x0001F00F, 0x0003F001, 0x0000FFFF);
233+
234+
test_vectorized_binary_logical<std::greater<>, sycl::short2>(
235+
0xFFF00002, 0xFFF00001, 0x0000FFFF);
236+
test_vectorized_binary_logical<std::greater<>, sycl::short2>(
237+
0x0003F00F, 0x0001F00F, 0xFFFF0000);
238+
239+
test_vectorized_binary_logical<std::less_equal<>, sycl::short2>(
240+
0xFFF00001, 0xF0F00002, 0x0000FFFF);
241+
test_vectorized_binary_logical<std::less_equal<>, sycl::short2>(
242+
0x0001FF0F, 0x0003F00F, 0xFFFF0000);
243+
244+
test_vectorized_binary_logical<std::less<>, sycl::short2>(
245+
0xFFF00001, 0xFFF00002, 0x0000FFFF);
246+
test_vectorized_binary_logical<std::less<>, sycl::short2>(
247+
0x0001F00F, 0x0003F00F, 0xFFFF0000);
248+
249+
// Logical Binary Operators v4
250+
test_vectorized_binary_logical<std::equal_to<>, sycl::uchar4>(
251+
0x0001F00F, 0x0003F00F, 0xFF00FFFF);
252+
test_vectorized_binary_logical<std::equal_to<>, sycl::uchar4>(
253+
0x0102F0F0, 0x0202F0FF, 0x00FFFF00);
254+
255+
test_vectorized_binary_logical<std::not_equal_to<>, sycl::uchar4>(
256+
0x0001F00F, 0xFF01F10F, 0xFF00FF00);
257+
test_vectorized_binary_logical<std::not_equal_to<>, sycl::uchar4>(
258+
0x0201F0F0, 0x0202F0FF, 0x00FF00FF);
259+
260+
test_vectorized_binary_logical<std::greater_equal<>, sycl::uchar4>(
261+
0xFFF00002, 0xFFF10101, 0xFF0000FF);
262+
test_vectorized_binary_logical<std::greater_equal<>, sycl::uchar4>(
263+
0x0001F1F0, 0x0103F001, 0x0000FFFF);
264+
265+
test_vectorized_binary_logical<std::greater<>, sycl::uchar4>(
266+
0xFFF00002, 0xF0F00001, 0xFF0000FF);
267+
test_vectorized_binary_logical<std::greater<>, sycl::uchar4>(
268+
0x0103F0F1, 0x0102F0F0, 0x00FF00FF);
269+
270+
test_vectorized_binary_logical<std::less_equal<>, sycl::uchar4>(
271+
0xFFF10001, 0xFFF00100, 0xFF00FF00);
272+
test_vectorized_binary_logical<std::less_equal<>, sycl::uchar4>(
273+
0x0101F1F0, 0x0003F0F1, 0x00FF00FF);
274+
275+
test_vectorized_binary_logical<std::less<>, sycl::uchar4>(
276+
0xFFF10001, 0xFFF20100, 0x00FFFF00);
277+
test_vectorized_binary_logical<std::less<>, sycl::uchar4>(
278+
0x0101F1F0, 0x0102F1F1, 0x00FF00FF);
279+
206280
return 0;
207281
}

0 commit comments

Comments
 (0)