Skip to content

Commit 05b0748

Browse files
committed
Generalize compare_mask & unordered_compare_mask
- compare_mask & unordered_compare_mask support sycl::marray - sycl::marray tests for all `compare` APIs
1 parent d16dd56 commit 05b0748

File tree

3 files changed

+41
-30
lines changed

3 files changed

+41
-30
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,14 +1856,13 @@ unordered_compare_both(const ValueT a, const ValueT b,
18561856
const BinaryOperation binary_op);
18571857
18581858
template <typename ValueT, class BinaryOperation>
1859-
inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
1860-
const sycl::vec<ValueT, 2> b,
1861-
const BinaryOperation binary_op);
1859+
inline std::enable_if_t<ValueT::size() == 2, unsigned>
1860+
compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op);
18621861
18631862
template <typename ValueT, class BinaryOperation>
1864-
inline unsigned unordered_compare_mask(const sycl::vec<ValueT, 2> a,
1865-
const sycl::vec<ValueT, 2> b,
1866-
const BinaryOperation binary_op);
1863+
inline std::enable_if_t<ValueT::size() == 2, unsigned>
1864+
unordered_compare_mask(const ValueT a, const ValueT b,
1865+
const BinaryOperation binary_op);
18671866
18681867
template <typename S, typename T> inline T vectorized_max(T a, T b);
18691868

sycl/include/syclcompat/math.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#pragma once
3333

3434
#include <sycl/feature_test.hpp>
35+
#include <type_traits>
3536

3637
// TODO(syclcompat-lib-reviewers): this should not be required
3738
#ifndef SYCL_EXT_ONEAPI_COMPLEX
@@ -578,9 +579,8 @@ unordered_compare_both(const ValueT a, const ValueT b,
578579
/// \param [in] binary_op functor that implements the binary operation
579580
/// \returns the comparison result
580581
template <typename ValueT, class BinaryOperation>
581-
inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
582-
const sycl::vec<ValueT, 2> b,
583-
const BinaryOperation binary_op) {
582+
inline std::enable_if_t<ValueT::size() == 2, unsigned>
583+
compare_mask(const ValueT a, const ValueT b, const BinaryOperation binary_op) {
584584
// Since compare returns 0 or 1, -compare will be 0x00000000 or 0xFFFFFFFF
585585
return ((-compare(a[0], b[0], binary_op)) << 16) |
586586
((-compare(a[1], b[1], binary_op)) & 0xFFFF);
@@ -594,9 +594,9 @@ inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
594594
/// \param [in] binary_op functor that implements the binary operation
595595
/// \returns the comparison result
596596
template <typename ValueT, class BinaryOperation>
597-
inline unsigned unordered_compare_mask(const sycl::vec<ValueT, 2> a,
598-
const sycl::vec<ValueT, 2> b,
599-
const BinaryOperation binary_op) {
597+
inline std::enable_if_t<ValueT::size() == 2, unsigned>
598+
unordered_compare_mask(const ValueT a, const ValueT b,
599+
const BinaryOperation binary_op) {
600600
return ((-unordered_compare(a[0], b[0], binary_op)) << 16) |
601601
((-unordered_compare(a[1], b[1], binary_op)) & 0xFFFF);
602602
}

sycl/test-e2e/syclcompat/math/math_compare.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,10 @@ void compare_not_equal_vec_kernel(Container *a, Container *b, Container *r) {
9696
*r = syclcompat::compare(*a, *b, std::not_equal_to<>());
9797
}
9898

99-
template <typename ValueT> void test_compare_vec() {
99+
template <template <typename T, int Dim> typename ContainerT,
100+
typename ValueT> void test_compare_vec() {
100101
std::cout << __PRETTY_FUNCTION__ << std::endl;
101-
using Container = sycl::vec<ValueT, 2>;
102+
using Container = ContainerT<ValueT, 2>;
102103

103104
constexpr syclcompat::dim3 grid{1};
104105
constexpr syclcompat::dim3 threads{1};
@@ -177,9 +178,10 @@ void unordered_compare_not_equal_vec_kernel(Container *a, Container *b,
177178
*r = syclcompat::unordered_compare(*a, *b, std::not_equal_to<>());
178179
}
179180

180-
template <typename ValueT> void test_unordered_compare_vec() {
181+
template <template <typename T, int Dim> typename ContainerT,
182+
typename ValueT> void test_unordered_compare_vec() {
181183
std::cout << __PRETTY_FUNCTION__ << std::endl;
182-
using Container = sycl::vec<ValueT, 2>;
184+
using Container = ContainerT<ValueT, 2>;
183185

184186
constexpr syclcompat::dim3 grid{1};
185187
constexpr syclcompat::dim3 threads{1};
@@ -207,9 +209,10 @@ void compare_both_kernel(Container *a, Container *b, bool *r) {
207209
*r = syclcompat::compare_both(*a, *b, std::equal_to<>());
208210
}
209211

210-
template <typename ValueT> void test_compare_both() {
212+
template <template <typename T, int Dim> typename ContainerT,
213+
typename ValueT> void test_compare_both() {
211214
std::cout << __PRETTY_FUNCTION__ << std::endl;
212-
using Container = sycl::vec<ValueT, 2>;
215+
using Container = ContainerT<ValueT, 2>;
213216

214217
constexpr syclcompat::dim3 grid{1};
215218
constexpr syclcompat::dim3 threads{1};
@@ -236,9 +239,10 @@ void unordered_compare_both_kernel(Container *a, Container *b, bool *r) {
236239
*r = syclcompat::unordered_compare_both(*a, *b, std::equal_to<>());
237240
}
238241

239-
template <typename ValueT> void test_unordered_compare_both() {
242+
template <template <typename T, int Dim> typename ContainerT,
243+
typename ValueT> void test_unordered_compare_both() {
240244
std::cout << __PRETTY_FUNCTION__ << std::endl;
241-
using Container = sycl::vec<ValueT, 2>;
245+
using Container = ContainerT<ValueT, 2>;
242246

243247
constexpr syclcompat::dim3 grid{1};
244248
constexpr syclcompat::dim3 threads{1};
@@ -266,9 +270,10 @@ void compare_mask_kernel(Container *a, Container *b, unsigned *r) {
266270
*r = syclcompat::compare_mask(*a, *b, std::equal_to<>());
267271
}
268272

269-
template <typename ValueT> void test_compare_mask() {
273+
template <template <typename T, int Dim> typename ContainerT,
274+
typename ValueT> void test_compare_mask() {
270275
std::cout << __PRETTY_FUNCTION__ << std::endl;
271-
using Container = sycl::vec<ValueT, 2>;
276+
using Container = ContainerT<ValueT, 2>;
272277

273278
constexpr syclcompat::dim3 grid{1};
274279
constexpr syclcompat::dim3 threads{1};
@@ -314,9 +319,10 @@ void unordered_compare_mask_kernel(Container *a, Container *b, unsigned *r) {
314319
*r = syclcompat::unordered_compare_mask(*a, *b, std::equal_to<>());
315320
}
316321

317-
template <typename ValueT> void test_unordered_compare_mask() {
322+
template <template <typename T, int Dim> typename ContainerT,
323+
typename ValueT> void test_unordered_compare_mask() {
318324
std::cout << __PRETTY_FUNCTION__ << std::endl;
319-
using Container = sycl::vec<ValueT, 2>;
325+
using Container = ContainerT<ValueT, 2>;
320326

321327
constexpr syclcompat::dim3 grid{1};
322328
constexpr syclcompat::dim3 threads{1};
@@ -360,12 +366,18 @@ template <typename ValueT> void test_unordered_compare_mask() {
360366
int main() {
361367
INSTANTIATE_ALL_TYPES(fp_type_list, test_compare);
362368
INSTANTIATE_ALL_TYPES(fp_type_list, test_unordered_compare);
363-
INSTANTIATE_ALL_TYPES(fp_type_list, test_compare_vec);
364-
INSTANTIATE_ALL_TYPES(fp_type_list, test_unordered_compare_vec);
365-
INSTANTIATE_ALL_TYPES(fp_type_list, test_compare_both);
366-
INSTANTIATE_ALL_TYPES(fp_type_list, test_unordered_compare_both);
367-
INSTANTIATE_ALL_TYPES(fp_type_list, test_compare_mask);
368-
INSTANTIATE_ALL_TYPES(fp_type_list, test_unordered_compare_mask);
369+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_compare_vec);
370+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::marray, test_compare_vec);
371+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_unordered_compare_vec);
372+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::marray, test_unordered_compare_vec);
373+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_compare_both);
374+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::marray, test_compare_both);
375+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_unordered_compare_both);
376+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::marray, test_unordered_compare_both);
377+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_compare_mask);
378+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::marray, test_compare_mask);
379+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::vec, test_unordered_compare_mask);
380+
INSTANTIATE_ALL_CONTAINER_TYPES(fp_type_list, sycl::marray, test_unordered_compare_mask);
369381

370382
return 0;
371383
}

0 commit comments

Comments
 (0)