@@ -199,12 +199,18 @@ template <typename T> bool test() {
199199
200200 sycl::buffer<std::complex <T>> data (testcases, sycl::range{N});
201201 sycl::buffer<std::complex <T>> results (sycl::range{N});
202+ sycl::buffer<std::complex <T>> exp_conj (sycl::range{N});
203+ sycl::buffer<std::complex <T>> conj_exp (sycl::range{N});
202204
203205 q.submit ([&](sycl::handler &cgh) {
204206 sycl::accessor acc_data (data, cgh, sycl::read_only);
205- sycl::accessor acc (results, cgh, sycl::write_only);
207+ sycl::accessor acc_results (results, cgh, sycl::write_only);
208+ sycl::accessor acc_exp_conj (exp_conj, cgh, sycl::write_only);
209+ sycl::accessor acc_conj_exp (conj_exp, cgh, sycl::write_only);
206210 cgh.parallel_for (sycl::range{N}, [=](sycl::item<1 > it) {
207- acc[it] = std::exp (acc_data[it]);
211+ acc_results[it] = std::exp (acc_data[it]);
212+ acc_exp_conj[it] = std::exp (std::conj (acc_data[it]));
213+ acc_conj_exp[it] = std::conj (std::exp (acc_data[it]));
208214 });
209215 }).wait_and_throw ();
210216
@@ -219,9 +225,18 @@ template <typename T> bool test() {
219225
220226 // Based on https://en.cppreference.com/w/cpp/numeric/complex/exp
221227 // z below refers to the argument passed to std::exp(complex<T>)
222- sycl::host_accessor acc (results);
228+ sycl::host_accessor acc_results (results);
229+ sycl::host_accessor acc_exp_conj (exp_conj);
230+ sycl::host_accessor acc_conj_exp (conj_exp);
223231 for (unsigned i = 0 ; i < N; ++i) {
224- std::complex <T> r = acc[i];
232+ // std::exp(std::conj(z)) == std::conj(std::exp(z))
233+ // NAN is not equal to NAN in floating-point arithmetic, therefore compare
234+ // only results without NAN
235+ if (!std::isnan (acc_exp_conj[i].real ()) &&
236+ !std::isnan (acc_exp_conj[i].imag ()))
237+ CHECK (acc_exp_conj[i] == acc_conj_exp[i], passed, i);
238+
239+ std::complex <T> r = acc_results[i];
225240 // If z is (+/-0, +0), the result is (1, +0)
226241 if (testcases[i].real () == 0 && testcases[i].imag () == 0 &&
227242 !std::signbit (testcases[i].imag ())) {
@@ -247,6 +262,33 @@ template <typename T> bool test() {
247262 CHECK (r.imag () == 0 , passed, i);
248263 CHECK (std::signbit (testcases[i].imag ()) == std::signbit (r.imag ()),
249264 passed, i);
265+ // If z is (-inf, y) (for any finite y), the result is +0cis(y) where
266+ // cis(y) is cos(y) + isin(y)
267+ } else if (std::isinf (testcases[i].real ()) &&
268+ std::signbit (testcases[i].real ()) &&
269+ std::isfinite (testcases[i].imag ())) {
270+ CHECK (r.real () == 0 , passed, i)
271+ CHECK (std::signbit (r.real ()) ==
272+ std::signbit (std::cos (testcases[i].imag ())),
273+ passed, i)
274+ CHECK (r.imag () == 0 , passed, i)
275+ CHECK (std::signbit (r.imag ()) ==
276+ std::signbit (std::sin (testcases[i].imag ())),
277+ passed, i)
278+ // If z is (+inf, y) (for any finite nonzero y), the result is +∞cis(y)
279+ // where cis(y) is cos(y) + isin(y)
280+ } else if (std::isinf (testcases[i].real ()) &&
281+ !std::signbit (testcases[i].real ()) &&
282+ std::isfinite (testcases[i].imag ()) &&
283+ testcases[i].imag () != 0 ) {
284+ CHECK (std::isinf (r.real ()), passed, i)
285+ CHECK (std::signbit (r.real ()) ==
286+ std::signbit (std::cos (testcases[i].imag ())),
287+ passed, i)
288+ CHECK (std::isinf (r.imag ()), passed, i)
289+ CHECK (std::signbit (r.imag ()) ==
290+ std::signbit (std::sin (testcases[i].imag ())),
291+ passed, i)
250292 // If z is (-inf, +inf), the result is (+/-0, +/-0) (signs are
251293 // unspecified)
252294 } else if (std::isinf (testcases[i].real ()) && testcases[i].real () < 0 &&
0 commit comments