Skip to content

Commit 3e7d6b3

Browse files
committed
[SYCL][E2E] Expand testing for std::exp(complex<double>)
1 parent 4797d65 commit 3e7d6b3

File tree

1 file changed

+324
-0
lines changed

1 file changed

+324
-0
lines changed
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
// This test checks edge cases handling for std::exp(std::complex<double>) used
2+
// in SYCL kernels.
3+
//
4+
// REQUIRES: aspect-fp64
5+
//
6+
// RUN: %{build} -o %t.out
7+
// RUN: %{run} %t.out
8+
9+
#include <sycl/detail/core.hpp>
10+
11+
#include <cmath>
12+
#include <complex>
13+
#include <set>
14+
15+
// To simplify maintanence of those comments specifying indexes of test cases
16+
// in the array below, please add new test cases at the end of the list.
17+
constexpr std::complex<double> testcases[] = {
18+
/* 0 */ std::complex<double>(1.e-6, 1.e-6),
19+
/* 1 */ std::complex<double>(-1.e-6, 1.e-6),
20+
/* 2 */ std::complex<double>(-1.e-6, -1.e-6),
21+
/* 3 */ std::complex<double>(1.e-6, -1.e-6),
22+
23+
/* 4 */ std::complex<double>(1.e+6, 1.e-6),
24+
/* 5 */ std::complex<double>(-1.e+6, 1.e-6),
25+
/* 6 */ std::complex<double>(-1.e+6, -1.e-6),
26+
/* 7 */ std::complex<double>(1.e+6, -1.e-6),
27+
28+
/* 8 */ std::complex<double>(1.e-6, 1.e+6),
29+
/* 9 */ std::complex<double>(-1.e-6, 1.e+6),
30+
/* 10 */ std::complex<double>(-1.e-6, -1.e+6),
31+
/* 11 */ std::complex<double>(1.e-6, -1.e+6),
32+
33+
/* 12 */ std::complex<double>(1.e+6, 1.e+6),
34+
/* 13 */ std::complex<double>(-1.e+6, 1.e+6),
35+
/* 14 */ std::complex<double>(-1.e+6, -1.e+6),
36+
/* 15 */ std::complex<double>(1.e+6, -1.e+6),
37+
38+
/* 16 */ std::complex<double>(-0, -1.e-6),
39+
/* 17 */ std::complex<double>(-0, 1.e-6),
40+
/* 18 */ std::complex<double>(-0, 1.e+6),
41+
/* 19 */ std::complex<double>(-0, -1.e+6),
42+
/* 20 */ std::complex<double>(0, -1.e-6),
43+
/* 21 */ std::complex<double>(0, 1.e-6),
44+
/* 22 */ std::complex<double>(0, 1.e+6),
45+
/* 23 */ std::complex<double>(0, -1.e+6),
46+
47+
/* 24 */ std::complex<double>(-1.e-6, -0),
48+
/* 25 */ std::complex<double>(1.e-6, -0),
49+
/* 26 */ std::complex<double>(1.e+6, -0),
50+
/* 27 */ std::complex<double>(-1.e+6, -0),
51+
/* 28 */ std::complex<double>(-1.e-6, 0),
52+
/* 29 */ std::complex<double>(1.e-6, 0),
53+
/* 30 */ std::complex<double>(1.e+6, 0),
54+
/* 31 */ std::complex<double>(-1.e+6, 0),
55+
56+
/* 32 */ std::complex<double>(NAN, NAN),
57+
/* 33 */ std::complex<double>(-INFINITY, NAN),
58+
/* 34 */ std::complex<double>(-2, NAN),
59+
/* 35 */ std::complex<double>(-1, NAN),
60+
/* 36 */ std::complex<double>(-0.5, NAN),
61+
/* 37 */ std::complex<double>(-0., NAN),
62+
/* 38 */ std::complex<double>(+0., NAN),
63+
/* 39 */ std::complex<double>(0.5, NAN),
64+
/* 40 */ std::complex<double>(1, NAN),
65+
/* 41 */ std::complex<double>(2, NAN),
66+
/* 42 */ std::complex<double>(INFINITY, NAN),
67+
68+
/* 43 */ std::complex<double>(NAN, -INFINITY),
69+
/* 44 */ std::complex<double>(-INFINITY, -INFINITY),
70+
/* 45 */ std::complex<double>(-2, -INFINITY),
71+
/* 46 */ std::complex<double>(-1, -INFINITY),
72+
/* 47 */ std::complex<double>(-0.5, -INFINITY),
73+
/* 48 */ std::complex<double>(-0., -INFINITY),
74+
/* 49 */ std::complex<double>(+0., -INFINITY),
75+
/* 50 */ std::complex<double>(0.5, -INFINITY),
76+
/* 51 */ std::complex<double>(1, -INFINITY),
77+
/* 52 */ std::complex<double>(2, -INFINITY),
78+
/* 53 */ std::complex<double>(INFINITY, -INFINITY),
79+
80+
/* 54 */ std::complex<double>(NAN, -2),
81+
/* 55 */ std::complex<double>(-INFINITY, -2),
82+
/* 56 */ std::complex<double>(-2, -2),
83+
/* 57 */ std::complex<double>(-1, -2),
84+
/* 58 */ std::complex<double>(-0.5, -2),
85+
/* 59 */ std::complex<double>(-0., -2),
86+
/* 60 */ std::complex<double>(+0., -2),
87+
/* 61 */ std::complex<double>(0.5, -2),
88+
/* 62 */ std::complex<double>(1, -2),
89+
/* 63 */ std::complex<double>(2, -2),
90+
/* 64 */ std::complex<double>(INFINITY, -2),
91+
92+
/* 65 */ std::complex<double>(NAN, -1),
93+
/* 66 */ std::complex<double>(-INFINITY, -1),
94+
/* 67 */ std::complex<double>(-2, -1),
95+
/* 68 */ std::complex<double>(-1, -1),
96+
/* 69 */ std::complex<double>(-0.5, -1),
97+
/* 70 */ std::complex<double>(-0., -1),
98+
/* 71 */ std::complex<double>(+0., -1),
99+
/* 72 */ std::complex<double>(0.5, -1),
100+
/* 73 */ std::complex<double>(1, -1),
101+
/* 74 */ std::complex<double>(2, -1),
102+
/* 75 */ std::complex<double>(INFINITY, -1),
103+
104+
/* 76 */ std::complex<double>(NAN, -0.5),
105+
/* 77 */ std::complex<double>(-INFINITY, -0.5),
106+
/* 78 */ std::complex<double>(-2, -0.5),
107+
/* 79 */ std::complex<double>(-1, -0.5),
108+
/* 80 */ std::complex<double>(-0.5, -0.5),
109+
/* 81 */ std::complex<double>(-0., -0.5),
110+
/* 82 */ std::complex<double>(+0., -0.5),
111+
/* 83 */ std::complex<double>(0.5, -0.5),
112+
/* 84 */ std::complex<double>(1, -0.5),
113+
/* 85 */ std::complex<double>(2, -0.5),
114+
/* 86 */ std::complex<double>(INFINITY, -0.5),
115+
116+
/* 87 */ std::complex<double>(NAN, -0.),
117+
/* 88 */ std::complex<double>(-INFINITY, -0.),
118+
/* 89 */ std::complex<double>(-2, -0.),
119+
/* 90 */ std::complex<double>(-1, -0.),
120+
/* 91 */ std::complex<double>(-0.5, -0.),
121+
/* 92 */ std::complex<double>(-0., -0.),
122+
/* 93 */ std::complex<double>(+0., -0.),
123+
/* 94 */ std::complex<double>(0.5, -0.),
124+
/* 95 */ std::complex<double>(1, -0.),
125+
/* 96 */ std::complex<double>(2, -0.),
126+
/* 97 */ std::complex<double>(INFINITY, -0.),
127+
128+
/* 98 */ std::complex<double>(NAN, +0.),
129+
/* 99 */ std::complex<double>(-INFINITY, +0.),
130+
/* 100 */ std::complex<double>(-2, +0.),
131+
/* 101 */ std::complex<double>(-1, +0.),
132+
/* 102 */ std::complex<double>(-0.5, +0.),
133+
/* 103 */ std::complex<double>(-0., +0.),
134+
/* 104 */ std::complex<double>(+0., +0.),
135+
/* 105 */ std::complex<double>(0.5, +0.),
136+
/* 106 */ std::complex<double>(1, +0.),
137+
/* 107 */ std::complex<double>(2, +0.),
138+
/* 108 */ std::complex<double>(INFINITY, +0.),
139+
140+
/* 109 */ std::complex<double>(NAN, 0.5),
141+
/* 110 */ std::complex<double>(-INFINITY, 0.5),
142+
/* 111 */ std::complex<double>(-2, 0.5),
143+
/* 112 */ std::complex<double>(-1, 0.5),
144+
/* 113 */ std::complex<double>(-0.5, 0.5),
145+
/* 114 */ std::complex<double>(-0., 0.5),
146+
/* 115 */ std::complex<double>(+0., 0.5),
147+
/* 116 */ std::complex<double>(0.5, 0.5),
148+
/* 117 */ std::complex<double>(1, 0.5),
149+
/* 118 */ std::complex<double>(2, 0.5),
150+
/* 119 */ std::complex<double>(INFINITY, 0.5),
151+
152+
/* 120 */ std::complex<double>(NAN, 1),
153+
/* 121 */ std::complex<double>(-INFINITY, 1),
154+
/* 122 */ std::complex<double>(-2, 1),
155+
/* 123 */ std::complex<double>(-1, 1),
156+
/* 124 */ std::complex<double>(-0.5, 1),
157+
/* 125 */ std::complex<double>(-0., 1),
158+
/* 126 */ std::complex<double>(+0., 1),
159+
/* 127 */ std::complex<double>(0.5, 1),
160+
/* 128 */ std::complex<double>(1, 1),
161+
/* 129 */ std::complex<double>(2, 1),
162+
/* 130 */ std::complex<double>(INFINITY, 1),
163+
164+
/* 131 */ std::complex<double>(NAN, 2),
165+
/* 132 */ std::complex<double>(-INFINITY, 2),
166+
/* 133 */ std::complex<double>(-2, 2),
167+
/* 134 */ std::complex<double>(-1, 2),
168+
/* 135 */ std::complex<double>(-0.5, 2),
169+
/* 136 */ std::complex<double>(-0., 2),
170+
/* 137 */ std::complex<double>(+0., 2),
171+
/* 138 */ std::complex<double>(0.5, 2),
172+
/* 139 */ std::complex<double>(1, 2),
173+
/* 140 */ std::complex<double>(2, 2),
174+
/* 141 */ std::complex<double>(INFINITY, 2),
175+
176+
/* 142 */ std::complex<double>(NAN, INFINITY),
177+
/* 143 */ std::complex<double>(-INFINITY, INFINITY),
178+
/* 144 */ std::complex<double>(-2, INFINITY),
179+
/* 145 */ std::complex<double>(-1, INFINITY),
180+
/* 146 */ std::complex<double>(-0.5, INFINITY),
181+
/* 147 */ std::complex<double>(-0., INFINITY),
182+
/* 148 */ std::complex<double>(+0., INFINITY),
183+
/* 149 */ std::complex<double>(0.5, INFINITY),
184+
/* 150 */ std::complex<double>(1, INFINITY),
185+
/* 151 */ std::complex<double>(2, INFINITY),
186+
/* 152 */ std::complex<double>(INFINITY, INFINITY)};
187+
188+
bool check(bool cond, const std::string &cond_str, int line, unsigned testcase,
189+
const std::set<unsigned> &known_fails) {
190+
if (!cond && !known_fails.count(testcase)) {
191+
std::cout << "Assertion " << cond_str << " (line " << line
192+
<< ") failed for testcase #" << testcase << std::endl;
193+
return false;
194+
} else if (cond && known_fails.count(testcase)) {
195+
std::cout << "Assertion " << cond_str << " (line " << line
196+
<< ") passed for testcase #" << testcase << std::endl;
197+
std::cout << "However, it was recorded as a known failure and therefore "
198+
"the test needs to be updated"
199+
<< std::endl;
200+
return false;
201+
}
202+
return true;
203+
}
204+
205+
int main() try {
206+
sycl::queue q;
207+
208+
constexpr unsigned N = sizeof(testcases) / sizeof(testcases[0]);
209+
210+
sycl::buffer<std::complex<double>> results(sycl::range{N});
211+
212+
q.submit([&](sycl::handler &cgh) {
213+
sycl::accessor acc(results, cgh, sycl::write_only);
214+
cgh.parallel_for(sycl::range{N}, [=](sycl::item<1> it) {
215+
acc[it] = std::exp(testcases[it]);
216+
});
217+
}).wait_and_throw();
218+
219+
// FIXME: the set below should be empty and therefore removed
220+
std::set<unsigned> known_fails = {32, 34, 35, 36, 37, 38, 39, 40, 41,
221+
43, 45, 46, 47, 48, 49, 50, 51, 52,
222+
54, 65, 76, 109, 120, 131, 142, 144, 145,
223+
146, 147, 148, 149, 150, 151};
224+
225+
bool passed = true;
226+
227+
// Note: this macro is expected to be used within a loop
228+
#define CHECK(cond, pass_marker, ...) \
229+
if (!check((cond), #cond, __LINE__, __VA_ARGS__)) { \
230+
pass_marker = false; \
231+
continue; \
232+
}
233+
234+
// Based on https://en.cppreference.com/w/cpp/numeric/complex/exp
235+
// z below refers to the argument passed to std::exp(complex<double>)
236+
sycl::host_accessor acc(results);
237+
for (unsigned i = 0; i < N; ++i) {
238+
std::complex<double> r = acc[i];
239+
// If z is (+/-0, +0), the result is (1, +0)
240+
if (testcases[i].real() == 0 && testcases[i].imag() == 0) {
241+
CHECK(r.real() == 1.0, passed, i, known_fails);
242+
CHECK(r.imag() == 0, passed, i, known_fails);
243+
CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), passed,
244+
i, known_fails);
245+
// If z is (x, +inf) (for any finite x), the result is (NaN, NaN)
246+
} else if (std::isfinite(testcases[i].real()) &&
247+
std::isinf(testcases[i].imag())) {
248+
CHECK(std::isnan(r.real()), passed, i, known_fails);
249+
CHECK(std::isnan(r.imag()), passed, i, known_fails);
250+
// If z is (x, NaN) (for any finite x), the result is (NaN, NaN)
251+
} else if (std::isfinite(testcases[i].real()) &&
252+
std::isnan(testcases[i].imag())) {
253+
CHECK(std::isnan(r.real()), passed, i, known_fails);
254+
CHECK(std::isnan(r.imag()), passed, i, known_fails);
255+
// If z is (+inf, +0), the result is (+inf, +0)
256+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 &&
257+
testcases[i].imag() == 0) {
258+
CHECK(std::isinf(r.real()), passed, i, known_fails);
259+
CHECK(r.real() > 0, passed, i, known_fails);
260+
CHECK(r.imag() == 0, passed, i, known_fails);
261+
CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), passed,
262+
i, known_fails);
263+
// If z is (-inf, +inf), the result is (+/-0, +/-0) (signs are
264+
// unspecified)
265+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 &&
266+
std::isinf(testcases[i].imag())) {
267+
CHECK(r.real() == 0, passed, i, known_fails);
268+
CHECK(r.imag() == 0, passed, i, known_fails);
269+
// If z is (+inf, +inf), the result is (+/-inf, NaN), (the sign of the
270+
// real part is unspecified)
271+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 &&
272+
std::isinf(testcases[i].imag())) {
273+
CHECK(std::isinf(r.real()), passed, i, known_fails);
274+
CHECK(std::isnan(r.imag()), passed, i, known_fails);
275+
// If z is (-inf, NaN), the result is (+/-0, +/-0) (signs are unspecified)
276+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 &&
277+
std::isnan(testcases[i].imag())) {
278+
CHECK(r.real() == 0, passed, i, known_fails);
279+
CHECK(r.imag() == 0, passed, i, known_fails);
280+
// If z is (+inf, NaN), the result is (+/-inf, NaN) (the sign of the real
281+
// part is unspecified)
282+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 &&
283+
std::isnan(testcases[i].imag())) {
284+
CHECK(std::isinf(r.real()), passed, i, known_fails);
285+
CHECK(std::isnan(r.imag()), passed, i, known_fails);
286+
// If z is (NaN, +0), the result is (NaN, +0)
287+
} else if (std::isnan(testcases[i].real()) && testcases[i].imag() == 0) {
288+
CHECK(std::isnan(r.real()), passed, i, known_fails);
289+
CHECK(r.imag() == 0, passed, i, known_fails);
290+
CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), passed,
291+
i, known_fails);
292+
// If z is (NaN, y) (for any nonzero y), the result is (NaN,NaN)
293+
} else if (std::isnan(testcases[i].real()) && testcases[i].imag() != 0) {
294+
CHECK(std::isnan(r.real()), passed, i, known_fails);
295+
CHECK(std::isnan(r.imag()), passed, i, known_fails);
296+
// If z is (NaN, NaN), the result is (NaN, NaN)
297+
} else if (std::isnan(testcases[i].real()) &&
298+
std::isnan(testcases[i].imag())) {
299+
CHECK(std::isnan(r.real()), passed, i, known_fails);
300+
CHECK(std::isnan(r.imag()), passed, i, known_fails);
301+
// Those tests were taken from oneDPL, not sure what is the corner case
302+
// they are covering here
303+
} else if (std::isfinite(testcases[i].imag()) &&
304+
std::abs(testcases[i].imag()) <= 1) {
305+
CHECK(!std::signbit(r.real()), passed, i, known_fails);
306+
CHECK(std::signbit(r.imag()) == std::signbit(testcases[i].imag()), passed,
307+
i, known_fails);
308+
// Those tests were taken from oneDPL, not sure what is the corner case
309+
// they are covering here
310+
} else if (std::isinf(r.real()) && testcases[i].imag() == 0) {
311+
CHECK(r.imag() == 0, passed, i, known_fails);
312+
CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), passed,
313+
i, known_fails);
314+
}
315+
// FIXME: do we have the following cases covered?
316+
// If z is (-inf, y) (for any finite y), the result is +0 cis(y)
317+
// If z is (+inf, y) (for any finite nonzero y), the result is +inf cis(y)
318+
}
319+
320+
return passed ? 0 : 1;
321+
} catch (sycl::exception &e) {
322+
std::cout << "Caught sync sycl exception: " << e.what() << std::endl;
323+
return 2;
324+
}

0 commit comments

Comments
 (0)