From 6d4a278b2f1ff4c93e68c140aaedd2e360f058fb Mon Sep 17 00:00:00 2001 From: KornevNikita Date: Fri, 11 Oct 2024 03:01:52 -0700 Subject: [PATCH 1/2] [SYCL] Add missing special values to exp(complex) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit exp(x,NaN) (for any finite x) = (NaN,NaN) exp(NaN,y) (for any nonzero y) = (NaN,NaN) exp(x,+∞) (for any finite x) = (NaN,NaN) https://en.cppreference.com/w/cpp/numeric/complex/exp E2E: https://github.com/intel/llvm/pull/15666 --- libdevice/fallback-complex-fp64.cpp | 10 +- libdevice/fallback-complex.cpp | 10 +- .../exp/exp-std-complex-double-edge-cases.cpp | 12 + .../exp/exp-std-complex-edge-cases.hpp | 316 ++++++++++++++++++ .../exp/exp-std-complex-float-edge-cases.cpp | 11 + 5 files changed, 355 insertions(+), 4 deletions(-) create mode 100644 sycl/test-e2e/DeviceLib/exp/exp-std-complex-double-edge-cases.cpp create mode 100644 sycl/test-e2e/DeviceLib/exp/exp-std-complex-edge-cases.hpp create mode 100644 sycl/test-e2e/DeviceLib/exp/exp-std-complex-float-edge-cases.cpp diff --git a/libdevice/fallback-complex-fp64.cpp b/libdevice/fallback-complex-fp64.cpp index 273ece4358067..11803a1b72f83 100644 --- a/libdevice/fallback-complex-fp64.cpp +++ b/libdevice/fallback-complex-fp64.cpp @@ -149,8 +149,14 @@ double __complex__ __devicelib_cexp(double __complex__ z) { z_imag = NAN; return CMPLX(z_real, z_imag); } - } else if (__spirv_IsNan(z_real) && (z_imag == 0.0)) { - return z; + } else if (__spirv_IsNan(z_real)) { + if (z_imag == 0.0) + return z; + else /* z_imag != 0.0 */ + return CMPLX(NAN, NAN); + } else if (__spirv_IsFinite(z_real)) { + if (__spirv_IsNan(z_imag) || __spirv_IsInf(z_imag)) + return CMPLX(NAN, NAN); } double __e = __spirv_ocl_exp(z_real); double ret_real = __e * __spirv_ocl_cos(z_imag); diff --git a/libdevice/fallback-complex.cpp b/libdevice/fallback-complex.cpp index e3f58b9eeb019..a1fba8b71c581 100644 --- a/libdevice/fallback-complex.cpp +++ 
b/libdevice/fallback-complex.cpp @@ -150,8 +150,14 @@ float __complex__ __devicelib_cexpf(float __complex__ z) { z_imag = NAN; return CMPLXF(z_real, z_imag); } - } else if (__spirv_IsNan(z_real) && (z_imag == 0.0f)) { - return z; + } else if (__spirv_IsNan(z_real)) { + if (z_imag == 0.0f) + return z; + else /* z_imag != 0.0 */ + return CMPLXF(NAN, NAN); + } else if (__spirv_IsFinite(z_real)) { + if (__spirv_IsNan(z_imag) || __spirv_IsInf(z_imag)) + return CMPLXF(NAN, NAN); } float __e = __spirv_ocl_exp(z_real); float ret_real = __e * __spirv_ocl_cos(z_imag); diff --git a/sycl/test-e2e/DeviceLib/exp/exp-std-complex-double-edge-cases.cpp b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-double-edge-cases.cpp new file mode 100644 index 0000000000000..791bda2cf1e61 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-double-edge-cases.cpp @@ -0,0 +1,12 @@ +// This test checks edge cases handling for std::exp(std::complex) used +// in SYCL kernels. +// +// REQUIRES: aspect-fp64 +// UNSUPPORTED: hip || cuda +// +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +#include "exp-std-complex-edge-cases.hpp" + +int main() { return test(); } diff --git a/sycl/test-e2e/DeviceLib/exp/exp-std-complex-edge-cases.hpp b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-edge-cases.hpp new file mode 100644 index 0000000000000..92362c6186979 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-edge-cases.hpp @@ -0,0 +1,316 @@ +// This test checks edge cases handling for std::exp(std::complex) used +// in SYCL kernels. 
+// +// REQUIRES: aspect-fp64 +// UNSUPPORTED: hip || cuda +// +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +#include + +#include +#include +#include + +bool check(bool cond, const std::string &cond_str, int line, + unsigned testcase) { + if (!cond) { + std::cout << "Assertion " << cond_str << " (line " << line + << ") failed for testcase #" << testcase << std::endl; + return false; + } + + return true; +} + +template bool test() { + // To simplify maintenance of those comments specifying indexes of test cases + // in the array below, please add new test cases at the end of the list. + constexpr std::complex testcases[] = { + /* 0 */ std::complex(1.e-6, 1.e-6), + /* 1 */ std::complex(-1.e-6, 1.e-6), + /* 2 */ std::complex(-1.e-6, -1.e-6), + /* 3 */ std::complex(1.e-6, -1.e-6), + + /* 4 */ std::complex(1.e+6, 1.e-6), + /* 5 */ std::complex(-1.e+6, 1.e-6), + /* 6 */ std::complex(-1.e+6, -1.e-6), + /* 7 */ std::complex(1.e+6, -1.e-6), + + /* 8 */ std::complex(1.e-6, 1.e+6), + /* 9 */ std::complex(-1.e-6, 1.e+6), + /* 10 */ std::complex(-1.e-6, -1.e+6), + /* 11 */ std::complex(1.e-6, -1.e+6), + + /* 12 */ std::complex(1.e+6, 1.e+6), + /* 13 */ std::complex(-1.e+6, 1.e+6), + /* 14 */ std::complex(-1.e+6, -1.e+6), + /* 15 */ std::complex(1.e+6, -1.e+6), + + /* 16 */ std::complex(-0, -1.e-6), + /* 17 */ std::complex(-0, 1.e-6), + /* 18 */ std::complex(-0, 1.e+6), + /* 19 */ std::complex(-0, -1.e+6), + /* 20 */ std::complex(0, -1.e-6), + /* 21 */ std::complex(0, 1.e-6), + /* 22 */ std::complex(0, 1.e+6), + /* 23 */ std::complex(0, -1.e+6), + + /* 24 */ std::complex(-1.e-6, -0), + /* 25 */ std::complex(1.e-6, -0), + /* 26 */ std::complex(1.e+6, -0), + /* 27 */ std::complex(-1.e+6, -0), + /* 28 */ std::complex(-1.e-6, 0), + /* 29 */ std::complex(1.e-6, 0), + /* 30 */ std::complex(1.e+6, 0), + /* 31 */ std::complex(-1.e+6, 0), + + /* 32 */ std::complex(NAN, NAN), + /* 33 */ std::complex(-INFINITY, NAN), + /* 34 */ std::complex(-2, NAN), + /* 35 */ std::complex(-1, 
NAN), + /* 36 */ std::complex(-0.5, NAN), + /* 37 */ std::complex(-0., NAN), + /* 38 */ std::complex(+0., NAN), + /* 39 */ std::complex(0.5, NAN), + /* 40 */ std::complex(1, NAN), + /* 41 */ std::complex(2, NAN), + /* 42 */ std::complex(INFINITY, NAN), + + /* 43 */ std::complex(NAN, -INFINITY), + /* 44 */ std::complex(-INFINITY, -INFINITY), + /* 45 */ std::complex(-2, -INFINITY), + /* 46 */ std::complex(-1, -INFINITY), + /* 47 */ std::complex(-0.5, -INFINITY), + /* 48 */ std::complex(-0., -INFINITY), + /* 49 */ std::complex(+0., -INFINITY), + /* 50 */ std::complex(0.5, -INFINITY), + /* 51 */ std::complex(1, -INFINITY), + /* 52 */ std::complex(2, -INFINITY), + /* 53 */ std::complex(INFINITY, -INFINITY), + + /* 54 */ std::complex(NAN, -2), + /* 55 */ std::complex(-INFINITY, -2), + /* 56 */ std::complex(-2, -2), + /* 57 */ std::complex(-1, -2), + /* 58 */ std::complex(-0.5, -2), + /* 59 */ std::complex(-0., -2), + /* 60 */ std::complex(+0., -2), + /* 61 */ std::complex(0.5, -2), + /* 62 */ std::complex(1, -2), + /* 63 */ std::complex(2, -2), + /* 64 */ std::complex(INFINITY, -2), + + /* 65 */ std::complex(NAN, -1), + /* 66 */ std::complex(-INFINITY, -1), + /* 67 */ std::complex(-2, -1), + /* 68 */ std::complex(-1, -1), + /* 69 */ std::complex(-0.5, -1), + /* 70 */ std::complex(-0., -1), + /* 71 */ std::complex(+0., -1), + /* 72 */ std::complex(0.5, -1), + /* 73 */ std::complex(1, -1), + /* 74 */ std::complex(2, -1), + /* 75 */ std::complex(INFINITY, -1), + + /* 76 */ std::complex(NAN, -0.5), + /* 77 */ std::complex(-INFINITY, -0.5), + /* 78 */ std::complex(-2, -0.5), + /* 79 */ std::complex(-1, -0.5), + /* 80 */ std::complex(-0.5, -0.5), + /* 81 */ std::complex(-0., -0.5), + /* 82 */ std::complex(+0., -0.5), + /* 83 */ std::complex(0.5, -0.5), + /* 84 */ std::complex(1, -0.5), + /* 85 */ std::complex(2, -0.5), + /* 86 */ std::complex(INFINITY, -0.5), + + /* 87 */ std::complex(NAN, -0.), + /* 88 */ std::complex(-INFINITY, -0.), + /* 89 */ std::complex(-2, -0.), + /* 90 
*/ std::complex(-1, -0.), + /* 91 */ std::complex(-0.5, -0.), + /* 92 */ std::complex(-0., -0.), + /* 93 */ std::complex(+0., -0.), + /* 94 */ std::complex(0.5, -0.), + /* 95 */ std::complex(1, -0.), + /* 96 */ std::complex(2, -0.), + /* 97 */ std::complex(INFINITY, -0.), + + /* 98 */ std::complex(NAN, +0.), + /* 99 */ std::complex(-INFINITY, +0.), + /* 100 */ std::complex(-2, +0.), + /* 101 */ std::complex(-1, +0.), + /* 102 */ std::complex(-0.5, +0.), + /* 103 */ std::complex(-0., +0.), + /* 104 */ std::complex(+0., +0.), + /* 105 */ std::complex(0.5, +0.), + /* 106 */ std::complex(1, +0.), + /* 107 */ std::complex(2, +0.), + /* 108 */ std::complex(INFINITY, +0.), + + /* 109 */ std::complex(NAN, 0.5), + /* 110 */ std::complex(-INFINITY, 0.5), + /* 111 */ std::complex(-2, 0.5), + /* 112 */ std::complex(-1, 0.5), + /* 113 */ std::complex(-0.5, 0.5), + /* 114 */ std::complex(-0., 0.5), + /* 115 */ std::complex(+0., 0.5), + /* 116 */ std::complex(0.5, 0.5), + /* 117 */ std::complex(1, 0.5), + /* 118 */ std::complex(2, 0.5), + /* 119 */ std::complex(INFINITY, 0.5), + + /* 120 */ std::complex(NAN, 1), + /* 121 */ std::complex(-INFINITY, 1), + /* 122 */ std::complex(-2, 1), + /* 123 */ std::complex(-1, 1), + /* 124 */ std::complex(-0.5, 1), + /* 125 */ std::complex(-0., 1), + /* 126 */ std::complex(+0., 1), + /* 127 */ std::complex(0.5, 1), + /* 128 */ std::complex(1, 1), + /* 129 */ std::complex(2, 1), + /* 130 */ std::complex(INFINITY, 1), + + /* 131 */ std::complex(NAN, 2), + /* 132 */ std::complex(-INFINITY, 2), + /* 133 */ std::complex(-2, 2), + /* 134 */ std::complex(-1, 2), + /* 135 */ std::complex(-0.5, 2), + /* 136 */ std::complex(-0., 2), + /* 137 */ std::complex(+0., 2), + /* 138 */ std::complex(0.5, 2), + /* 139 */ std::complex(1, 2), + /* 140 */ std::complex(2, 2), + /* 141 */ std::complex(INFINITY, 2), + + /* 142 */ std::complex(NAN, INFINITY), + /* 143 */ std::complex(-INFINITY, INFINITY), + /* 144 */ std::complex(-2, INFINITY), + /* 145 */ 
std::complex(-1, INFINITY), + /* 146 */ std::complex(-0.5, INFINITY), + /* 147 */ std::complex(-0., INFINITY), + /* 148 */ std::complex(+0., INFINITY), + /* 149 */ std::complex(0.5, INFINITY), + /* 150 */ std::complex(1, INFINITY), + /* 151 */ std::complex(2, INFINITY), + /* 152 */ std::complex(INFINITY, INFINITY)}; + + try { + sycl::queue q; + + constexpr unsigned N = sizeof(testcases) / sizeof(testcases[0]); + + sycl::buffer> results(sycl::range{N}); + + q.submit([&](sycl::handler &cgh) { + sycl::accessor acc(results, cgh, sycl::write_only); + cgh.parallel_for(sycl::range{N}, [=](sycl::item<1> it) { + acc[it] = std::exp(testcases[it]); + }); + }).wait_and_throw(); + + bool passed = true; + + // Note: this macro is expected to be used within a loop +#define CHECK(cond, pass_marker, ...) \ + if (!check((cond), #cond, __LINE__, __VA_ARGS__)) { \ + pass_marker = false; \ + continue; \ + } + + // Based on https://en.cppreference.com/w/cpp/numeric/complex/exp + // z below refers to the argument passed to std::exp(complex) + sycl::host_accessor acc(results); + for (unsigned i = 0; i < N; ++i) { + std::complex r = acc[i]; + // If z is (+/-0, +0), the result is (1, +0) + if (testcases[i].real() == 0 && testcases[i].imag() == 0) { + CHECK(r.real() == 1.0, passed, i); + CHECK(r.imag() == 0, passed, i); + CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), + passed, i); + // If z is (x, +inf) (for any finite x), the result is (NaN, NaN) + } else if (std::isfinite(testcases[i].real()) && + std::isinf(testcases[i].imag())) { + CHECK(std::isnan(r.real()), passed, i); + CHECK(std::isnan(r.imag()), passed, i); + // If z is (x, NaN) (for any finite x), the result is (NaN, NaN) + } else if (std::isfinite(testcases[i].real()) && + std::isnan(testcases[i].imag())) { + CHECK(std::isnan(r.real()), passed, i); + CHECK(std::isnan(r.imag()), passed, i); + // If z is (+inf, +0), the result is (+inf, +0) + } else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 
&& + testcases[i].imag() == 0) { + CHECK(std::isinf(r.real()), passed, i); + CHECK(r.real() > 0, passed, i); + CHECK(r.imag() == 0, passed, i); + CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), + passed, i); + // If z is (-inf, +inf), the result is (+/-0, +/-0) (signs are + // unspecified) + } else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 && + std::isinf(testcases[i].imag())) { + CHECK(r.real() == 0, passed, i); + CHECK(r.imag() == 0, passed, i); + // If z is (+inf, +inf), the result is (+/-inf, NaN), (the sign of the + // real part is unspecified) + } else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 && + std::isinf(testcases[i].imag())) { + CHECK(std::isinf(r.real()), passed, i); + CHECK(std::isnan(r.imag()), passed, i); + // If z is (-inf, NaN), the result is (+/-0, +/-0) (signs are + // unspecified) + } else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 && + std::isnan(testcases[i].imag())) { + CHECK(r.real() == 0, passed, i); + CHECK(r.imag() == 0, passed, i); + // If z is (+inf, NaN), the result is (+/-inf, NaN) (the sign of the + // real part is unspecified) + } else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 && + std::isnan(testcases[i].imag())) { + CHECK(std::isinf(r.real()), passed, i); + CHECK(std::isnan(r.imag()), passed, i); + // If z is (NaN, +0), the result is (NaN, +0) + } else if (std::isnan(testcases[i].real()) && testcases[i].imag() == 0) { + CHECK(std::isnan(r.real()), passed, i); + CHECK(r.imag() == 0, passed, i); + CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), + passed, i); + // If z is (NaN, y) (for any nonzero y), the result is (NaN,NaN) + } else if (std::isnan(testcases[i].real()) && testcases[i].imag() != 0) { + CHECK(std::isnan(r.real()), passed, i); + CHECK(std::isnan(r.imag()), passed, i); + // If z is (NaN, NaN), the result is (NaN, NaN) + } else if (std::isnan(testcases[i].real()) && + std::isnan(testcases[i].imag())) { 
+ CHECK(std::isnan(r.real()), passed, i); + CHECK(std::isnan(r.imag()), passed, i); + // Those tests were taken from oneDPL, not sure what is the corner case + // they are covering here + } else if (std::isfinite(testcases[i].imag()) && + std::abs(testcases[i].imag()) <= 1) { + CHECK(!std::signbit(r.real()), passed, i); + CHECK(std::signbit(r.imag()) == std::signbit(testcases[i].imag()), + passed, i); + // Those tests were taken from oneDPL, not sure what is the corner case + // they are covering here + } else if (std::isinf(r.real()) && testcases[i].imag() == 0) { + CHECK(r.imag() == 0, passed, i); + CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), + passed, i); + } + // FIXME: do we have the following cases covered? + // If z is (-inf, y) (for any finite y), the result is +0 cis(y) + // If z is (+inf, y) (for any finite nonzero y), the result is +inf cis(y) + } + + return passed ? 0 : 1; + } catch (sycl::exception &e) { + std::cout << "Caught sync sycl exception: " << e.what() << std::endl; + return 2; + } +} diff --git a/sycl/test-e2e/DeviceLib/exp/exp-std-complex-float-edge-cases.cpp b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-float-edge-cases.cpp new file mode 100644 index 0000000000000..9ba1b932ba562 --- /dev/null +++ b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-float-edge-cases.cpp @@ -0,0 +1,11 @@ +// This test checks edge cases handling for std::exp(std::complex) used +// in SYCL kernels. 
+// +// UNSUPPORTED: hip || cuda +// +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +#include "exp-std-complex-edge-cases.hpp" + +int main() { return test(); } From ad9963d538a544bbd5537f40736afdc12b2af4ab Mon Sep 17 00:00:00 2001 From: "Kornev, Nikita" Date: Mon, 14 Oct 2024 07:41:45 -0700 Subject: [PATCH 2/2] remove float part --- libdevice/fallback-complex.cpp | 10 ++-------- .../exp/exp-std-complex-float-edge-cases.cpp | 11 ----------- 2 files changed, 2 insertions(+), 19 deletions(-) delete mode 100644 sycl/test-e2e/DeviceLib/exp/exp-std-complex-float-edge-cases.cpp diff --git a/libdevice/fallback-complex.cpp b/libdevice/fallback-complex.cpp index a1fba8b71c581..e3f58b9eeb019 100644 --- a/libdevice/fallback-complex.cpp +++ b/libdevice/fallback-complex.cpp @@ -150,14 +150,8 @@ float __complex__ __devicelib_cexpf(float __complex__ z) { z_imag = NAN; return CMPLXF(z_real, z_imag); } - } else if (__spirv_IsNan(z_real)) { - if (z_imag == 0.0f) - return z; - else /* z_imag != 0.0 */ - return CMPLXF(NAN, NAN); - } else if (__spirv_IsFinite(z_real)) { - if (__spirv_IsNan(z_imag) || __spirv_IsInf(z_imag)) - return CMPLXF(NAN, NAN); + } else if (__spirv_IsNan(z_real) && (z_imag == 0.0f)) { + return z; } float __e = __spirv_ocl_exp(z_real); float ret_real = __e * __spirv_ocl_cos(z_imag); diff --git a/sycl/test-e2e/DeviceLib/exp/exp-std-complex-float-edge-cases.cpp b/sycl/test-e2e/DeviceLib/exp/exp-std-complex-float-edge-cases.cpp deleted file mode 100644 index 9ba1b932ba562..0000000000000 --- a/sycl/test-e2e/DeviceLib/exp/exp-std-complex-float-edge-cases.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// This test checks edge cases handling for std::exp(std::complex) used -// in SYCL kernels. -// -// UNSUPPORTED: hip || cuda -// -// RUN: %{build} -o %t.out -// RUN: %{run} %t.out - -#include "exp-std-complex-edge-cases.hpp" - -int main() { return test(); }