Skip to content

Commit 320d1a8

Browse files
committed
[SYCL] Add e2e test for llvm.scmp/ucmp.*
Signed-off-by: Marcos Maronas <[email protected]>
1 parent 1dfc3d7 commit 320d1a8

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// RUN: %{build} -Wno-error=psabi -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <sycl/detail/core.hpp>
5+
6+
// Define vector types for different integer bit widths. We need these to
7+
// trigger llvm.scmp/ucmp for vector types. std::array or sycl::vec don't
8+
// trigger these, as they are not lowered to vector types.
9+
typedef int8_t v4i8_t __attribute__((ext_vector_type(4)));
10+
typedef int16_t v4i16_t __attribute__((ext_vector_type(4)));
11+
typedef int32_t v4i32_t __attribute__((ext_vector_type(4)));
12+
typedef int64_t v4i64_t __attribute__((ext_vector_type(4)));
13+
typedef uint8_t v4u8_t __attribute__((ext_vector_type(4)));
14+
typedef uint16_t v4u16_t __attribute__((ext_vector_type(4)));
15+
typedef uint32_t v4u32_t __attribute__((ext_vector_type(4)));
16+
typedef uint64_t v4u64_t __attribute__((ext_vector_type(4)));
17+
18+
// Check if a given type is a vector type or not. Used in submitAndCheck to
19+
// branch the check: we need element-wise comparison for vector types. Default
20+
// case: T is not a vector type.
21+
template <typename T>
22+
struct is_vector : std::false_type {};
23+
// Specialization for vector types. If T has
24+
// __attribute__((ext_vector_type(N))), then it's a vector type.
25+
template <typename T, std::size_t N>
26+
struct is_vector<T __attribute__((ext_vector_type(N)))> : std::true_type {};
27+
template <typename T>
28+
inline constexpr bool is_vector_v = is_vector<T>::value;
29+
30+
// Get the length of a vector type. Used in submitAndCheck to iterate over the
31+
// elements of the vector type. Default case: length is 1.
32+
template <typename T>
33+
struct vector_length {
34+
static constexpr std::size_t value = 1;
35+
};
36+
// Specialization for vector types. If T has
37+
// __attribute__((ext_vector_type(N))), then the length is N.
38+
template <typename T, std::size_t N>
39+
struct vector_length<T __attribute__((ext_vector_type(N)))> {
40+
static constexpr std::size_t value = N;
41+
};
42+
template <typename T>
43+
inline constexpr std::size_t vector_length_v = vector_length<T>::value;
44+
45+
// Get the element type of a vector type. Used in submitVecCombinations to
46+
// convert unsigned vector types to signed vector types for return type. Primary
47+
// template for element_type.
48+
template <typename T>
49+
struct element_type;
50+
// Specialization for vector types. If T has __attribute__((ext_vector_type(N))), return T.
51+
template <typename T, int N>
52+
struct element_type<T __attribute__((ext_vector_type(N)))> {
53+
using type = T;
54+
};
55+
// Helper alias template.
56+
template <typename T>
57+
using element_type_t = typename element_type<T>::type;
58+
59+
// TypeList for packing the types that we want to test.
60+
// Base case for variadic template recursion.
61+
template <typename...>
62+
struct TypeList {};
63+
64+
// Function to trigger llvm.scmp/ucmp.
65+
template <typename RetTy, typename ArgTy>
66+
void compare(RetTy &res, ArgTy x, ArgTy y) {
67+
auto lessOrEq = (x <= y);
68+
auto lessThan = (x < y);
69+
res = lessOrEq ? (lessThan ? RetTy(-1) : RetTy(0)) : RetTy(1);
70+
}
71+
72+
// Function to submit kernel and check device result with host result.
73+
template <typename RetTy, typename ArgTy>
74+
void submitAndCheck(sycl::queue &q, ArgTy x, ArgTy y) {
75+
RetTy res;
76+
{
77+
sycl::buffer<RetTy, 1> res_b{&res, 1};
78+
q.submit([&](sycl::handler &cgh) {
79+
sycl::accessor acc{res_b, cgh, sycl::write_only};
80+
cgh.single_task<>([=] {
81+
RetTy tmp;
82+
compare<RetTy, ArgTy>(tmp, x, y);
83+
acc[0] = tmp;
84+
});
85+
});
86+
}
87+
RetTy expectedRes;
88+
compare<RetTy, ArgTy>(expectedRes, x, y);
89+
if constexpr (is_vector_v<RetTy>) {
90+
for (int i = 0; i < vector_length_v<RetTy>; ++i) {
91+
assert(res[i] == expectedRes[i]);
92+
}
93+
} else {
94+
assert(res == expectedRes);
95+
}
96+
}
97+
98+
// Helper to call submitAndCheck for each combination.
99+
template <typename RetTypes, typename ArgTypes>
100+
void submitAndCheckCombination(sycl::queue &q, int x, int y) {
101+
submitAndCheck<RetTypes, ArgTypes>(q, x, y);
102+
}
103+
104+
// Function to generate all the combinations possible with the two type lists.
105+
// It implements the following pseudocode :
106+
// foreach RetTy : RetTypes
107+
// foreach ArgTy : ArgTypes
108+
// submitAndCheck<RetTy, ArgTy>(q, x, y);
109+
110+
// Recursive case to generate combinations.
111+
template <typename RetType, typename... RetTypes, typename... ArgTypes>
112+
void submitCombinations(sycl::queue &q, int x, int y, TypeList<RetType, RetTypes...>, TypeList<ArgTypes...>) {
113+
(submitAndCheckCombination<RetType, ArgTypes>(q, x, y), ...);
114+
submitCombinations(q, x, y, TypeList<RetTypes...>{}, TypeList<ArgTypes...>{});
115+
}
116+
// Base case to stop recursion.
117+
template <typename... ArgTypes>
118+
void submitCombinations(sycl::queue &, int, int, TypeList<>, TypeList<ArgTypes...>) {}
119+
120+
// Function to generate all the combinations out of the given list.
121+
// It implements the following pseudocode :
122+
// foreach ArgTy : ArgTypes
123+
// submitAndCheck<ArgTy, ArgTy>(q, x, y);
124+
125+
// Recursive case to generate combinations.
126+
template <typename ArgType, typename... ArgTypes>
127+
void submitVecCombinations(sycl::queue &q, int x, int y, TypeList<ArgType, ArgTypes...>) {
128+
// Use signed types for return type, as it may return -1.
129+
using ElemType = std::make_signed_t<element_type_t<ArgType>>;
130+
using RetType = ElemType __attribute__((ext_vector_type(vector_length_v<ArgType>)));
131+
submitAndCheckCombination<RetType, ArgType>(q, x, y);
132+
submitVecCombinations(q, x, y, TypeList<ArgTypes...>{});
133+
}
134+
// Base case to stop recursion.
135+
void submitVecCombinations(sycl::queue &, int, int, TypeList<>) {}
136+
137+
int main(int argc, char **argv) {
138+
sycl::queue q;
139+
// RetTypes includes only signed types because it may return -1.
140+
using RetTypes = TypeList<int8_t, int16_t, int32_t, int64_t>;
141+
using ArgTypes = TypeList<int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t>;
142+
submitCombinations(q, 50, 49, RetTypes{}, ArgTypes{});
143+
submitCombinations(q, 50, 50, RetTypes{}, ArgTypes{});
144+
submitCombinations(q, 50, 51, RetTypes{}, ArgTypes{});
145+
using VecTypes = TypeList<v4i8_t, v4i16_t, v4i32_t, v4i64_t, v4u8_t, v4u16_t, v4u32_t, v4u64_t>;
146+
submitVecCombinations(q, 50, 49, VecTypes{});
147+
submitVecCombinations(q, 50, 50, VecTypes{});
148+
submitVecCombinations(q, 50, 51, VecTypes{});
149+
return 0;
150+
}

0 commit comments

Comments
 (0)