1+ #include " common.h"
2+
3+ struct binops_tests {
4+ template <typename T, size_t ... I, size_t N = sizeof ...(I)>
5+ __host__ __device__ void operator ()(generator<T> gen, std::index_sequence<I...>) {
6+ T x[N] = {gen.next (I)...};
7+ T y[N] = {gen.next (I)...};
8+
9+ kf::vec<T, N> a = {x[I]...};
10+ kf::vec<T, N> b = {y[I]...};
11+ kf::vec<T, N> c;
12+
13+ // Arithmetic
14+ c = a + b;
15+ ASSERT (equals (T (x[I] + y[I]), c[I]) && ...);
16+
17+ c = a - b;
18+ ASSERT (equals (T (x[I] - y[I]), c[I]) && ...);
19+
20+ c = a * b;
21+ ASSERT (equals (T (x[I] * y[I]), c[I]) && ...);
22+
23+ // Results in division by zero
24+ // c = a / b;
25+ // ASSERT(equals(T(x[I] / y[I]), c[I]) && ...);
26+
27+ // Results in division by zero
28+ // c = a % b;
29+ // ASSERT(equals(T(x[I] % y[I]), c[I]) && ...);
30+
31+ // Comparison
32+ c = a < b;
33+ ASSERT (equals (T (x[I] < y[I]), c[I]) && ...);
34+
35+ c = a > b;
36+ ASSERT (equals (T (x[I] > y[I]), c[I]) && ...);
37+
38+ c = a <= b;
39+ ASSERT (equals (T (x[I] <= y[I]), c[I]) && ...);
40+
41+ c = a >= b;
42+ ASSERT (equals (T (x[I] >= y[I]), c[I]) && ...);
43+
44+ c = a == b;
45+ ASSERT (equals (T (x[I] == y[I]), c[I]) && ...);
46+
47+ c = a != b;
48+ ASSERT (equals (T (x[I] != y[I]), c[I]) && ...);
49+
50+ // Assignment
51+ c = a;
52+ c += b;
53+ ASSERT (equals (T (x[I] + y[I]), c[I]) && ...);
54+
55+ c = a;
56+ c -= b;
57+ ASSERT (equals (T (x[I] - y[I]), c[I]) && ...);
58+
59+ c = a;
60+ c *= b;
61+ ASSERT (equals (T (x[I] * y[I]), c[I]) && ...);
62+ }
63+ };
64+
65+ REGISTER_TEST_CASE (" binary operators" , binops_tests, bool , int , float , double )
66+ REGISTER_TEST_CASE_GPU(" binary operators" , binops_tests, __half, __nv_bfloat16)
67+
68+ struct binops_float_tests {
69+ template <typename T, size_t ... I, size_t N = sizeof ...(I)>
70+ __host__ __device__ void operator ()(generator<T> gen, std::index_sequence<I...>) {
71+ T x[N] = {gen.next (I)...};
72+ T y[N] = {gen.next (I)...};
73+
74+ kf::vec<T, N> a = {x[I]...};
75+ kf::vec<T, N> b = {y[I]...};
76+ kf::vec<T, N> c;
77+
78+ c = a / b;
79+ ASSERT (equals (T (x[I] / y[I]), c[I]) && ...);
80+
81+ // remainder is not support for fp16
82+ if constexpr (is_none_of<T, __half, __nv_bfloat16>) {
83+ // c = a % b;
84+ // ASSERT(equals(T(fmod(x[I], y[I])), c[I]) && ...);
85+ }
86+ }
87+ };
88+
89+ REGISTER_TEST_CASE (" binary float operators" , binops_float_tests, float , double )
90+ REGISTER_TEST_CASE_GPU(" binary float operators" , binops_float_tests, __half, __nv_bfloat16)
91+
92+ struct minmax_tests {
93+ template <typename T, size_t ... I, size_t N = sizeof ...(I)>
94+ __host__ __device__ void operator ()(generator<T> gen, std::index_sequence<I...>) {
95+ T x[N] = {gen.next (I)...};
96+ T y[N] = {gen.next (I)...};
97+
98+ kf::vec<T, N> a = {x[I]...};
99+ kf::vec<T, N> b = {y[I]...};
100+
101+ kf::vec<T, N> lo = min (a, b);
102+ kf::vec<T, N> hi = max (a, b);
103+
104+ if constexpr (is_one_of<T, double >) {
105+ ASSERT (equals (fmin (a[I], b[I]), lo[I]) && ...);
106+ ASSERT (equals (fmax (a[I], b[I]), hi[I]) && ...);
107+ } else if constexpr (is_one_of<T, float >) {
108+ ASSERT (equals (fminf (a[I], b[I]), lo[I]) && ...);
109+ ASSERT (equals (fmaxf (a[I], b[I]), hi[I]) && ...);
110+ } else if constexpr (is_one_of<T, __half, __nv_bfloat16>) {
111+ ASSERT (equals (__hmin (a[I], b[I]), lo[I]) && ...);
112+ ASSERT (equals (__hmax (a[I], b[I]), hi[I]) && ...);
113+ } else {
114+ ASSERT (equals (x[I] < y[I] ? x[I] : y[I], lo[I]) && ...);
115+ ASSERT (equals (x[I] < y[I] ? y[I] : x[I], hi[I]) && ...);
116+ }
117+ }
118+ };
119+
120+ REGISTER_TEST_CASE (" min/max functions" , minmax_tests, bool , int , float , double )
121+ REGISTER_TEST_CASE_GPU(" min/max functions" , minmax_tests, __half, __nv_bfloat16)
122+
123+ struct cross_test {
124+ template <typename T>
125+ __host__ __device__ void operator ()(generator<T> gen) {
126+ kf::vec<T, 3 > a = {1 , 2 , 3 };
127+ kf::vec<T, 3 > b = {4 , 5 , 6 };
128+ kf::vec<T, 3 > c = cross (a, b);
129+
130+ ASSERT (c[0 ] == T (-3 ));
131+ ASSERT (c[1 ] == T (6 ));
132+ ASSERT (c[2 ] == T (-3 ));
133+ }
134+ };
135+
136+ REGISTER_TEST_CASE (" cross product" , cross_test, float , double )
137+ REGISTER_TEST_CASE_GPU(" cross product" , cross_test, __half, __nv_bfloat16)
0 commit comments