Skip to content

Commit 3008955

Browse files
[SYCL] Fix complex tanh (#20636)
The current implementation of the tanh function for the SYCL complex extension did not always return values precise enough, neither on device nor on host. This was caused by accumulated error from a call to cos. This commit changes the implementation of tanh to use a different way of calculating the result and tests it against std::tanh for complex numbers. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 67ac635 commit 3008955

File tree

2 files changed

+155
-14
lines changed

2 files changed

+155
-14
lines changed

sycl/include/sycl/ext/oneapi/experimental/complex/detail/complex_math.hpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -493,22 +493,21 @@ template <class _Tp>
493493
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY
494494
typename std::enable_if_t<is_genfloat<_Tp>::value, complex<_Tp>>
495495
tanh(const complex<_Tp> &__x) {
496-
if (sycl::isinf(__x.real())) {
497-
if (!sycl::isfinite(__x.imag()))
498-
return complex<_Tp>(sycl::copysign(_Tp(1), __x.real()), _Tp(0));
499-
return complex<_Tp>(sycl::copysign(_Tp(1), __x.real()),
500-
sycl::copysign(_Tp(0), sycl::sin(_Tp(2) * __x.imag())));
501-
}
496+
if (sycl::isinf(__x.real()))
497+
return complex<_Tp>(
498+
sycl::copysign(_Tp(1), __x.real()),
499+
sycl::copysign(_Tp(0), sycl::isfinite(__x.imag())
500+
? sycl::sin(_Tp(2) * __x.imag())
501+
: _Tp(1)));
502502
if (sycl::isnan(__x.real()) && __x.imag() == 0)
503503
return __x;
504-
_Tp __2r(_Tp(2) * __x.real());
505-
_Tp __2i(_Tp(2) * __x.imag());
506-
_Tp __d(sycl::cosh(__2r) + sycl::cos(__2i));
507-
_Tp __2rsh(sycl::sinh(__2r));
508-
if (sycl::isinf(__2rsh) && sycl::isinf(__d))
509-
return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1),
510-
__2i > _Tp(0) ? _Tp(0) : _Tp(-0.));
511-
return complex<_Tp>(__2rsh / __d, sycl::sin(__2i) / __d);
504+
complex<_Tp> sinh_x = sinh(__x);
505+
complex<_Tp> cosh_x = cosh(__x);
506+
if (sycl::isinf(sinh_x.real()) && sycl::isinf(cosh_x.real()))
507+
return complex<_Tp>(sinh_x.real() * cosh_x.real() > _Tp(0) ? _Tp(1)
508+
: _Tp(-1),
509+
__x.imag() > _Tp(0) ? _Tp(0) : _Tp(-0.));
510+
return sinh_x / cosh_x;
512511
}
513512

514513
// asin
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
4+
// Checks the results of tanh on certain complex numbers.
5+
6+
#define SYCL_EXT_ONEAPI_COMPLEX
7+
8+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
9+
#include <sycl/sycl.hpp>
10+
11+
#include <complex>
12+
#include <limits>
13+
14+
namespace syclexp = sycl::ext::oneapi::experimental;
15+
16+
int Failures = 0;
17+
18+
template <typename T> bool FloatingPointEq(T LHS, T RHS) {
19+
if (std::isnan(LHS))
20+
return std::isnan(RHS);
21+
// Allow some rounding differences, but minimal.
22+
return std::abs(LHS - RHS) < T{0.0001};
23+
}
24+
25+
#define CHECK_TANH_RESULT(REAL, IMAG, T) \
26+
{ \
27+
syclexp::complex<T> sycl_res = \
28+
syclexp::tanh(syclexp::complex<T>{REAL, IMAG}); \
29+
std::complex<T> std_res = std::tanh(std::complex<T>{REAL, IMAG}); \
30+
if (!FloatingPointEq(sycl_res.real(), std_res.real())) { \
31+
std::cout << "Real differ in tanh((" << REAL << ", " << IMAG \
32+
<< ")): " << sycl_res.real() << " != " << std_res.real() \
33+
<< std::endl; \
34+
++Failures; \
35+
} \
36+
if (!FloatingPointEq(sycl_res.imag(), std_res.imag())) { \
37+
std::cout << "Imag differ in tanh((" << REAL << ", " << IMAG \
38+
<< ")): " << sycl_res.imag() << " != " << std_res.imag() \
39+
<< std::endl; \
40+
++Failures; \
41+
} \
42+
}
43+
44+
#define CHECK_TANH_REF_RESULT(REAL, IMAG, REF_REAL, REF_IMAG, T) \
45+
{ \
46+
syclexp::complex<T> sycl_res = \
47+
syclexp::tanh(syclexp::complex<T>{REAL, IMAG}); \
48+
if (!FloatingPointEq(sycl_res.real(), T{REF_REAL})) { \
49+
std::cout << "Real differ in tanh((" << REAL << ", " << IMAG \
50+
<< ")): " << sycl_res.real() << " != " << REF_REAL \
51+
<< std::endl; \
52+
++Failures; \
53+
} \
54+
if (!FloatingPointEq(sycl_res.imag(), T{REF_IMAG})) { \
55+
std::cout << "Imag differ in tanh((" << REAL << ", " << IMAG \
56+
<< ")): " << sycl_res.imag() << " != " << REF_IMAG \
57+
<< std::endl; \
58+
++Failures; \
59+
} \
60+
}
61+
62+
int main() {
63+
// Set precision for easier debugging.
64+
std::cout << std::setprecision(10);
65+
66+
CHECK_TANH_RESULT(0, -11.0, float);
67+
CHECK_TANH_RESULT(0, -11.0, double);
68+
69+
CHECK_TANH_RESULT(32.0, std::numeric_limits<float>::infinity(), float);
70+
CHECK_TANH_RESULT(32.0, std::numeric_limits<double>::infinity(), double);
71+
72+
CHECK_TANH_RESULT(32.0, -std::numeric_limits<float>::infinity(), float);
73+
CHECK_TANH_RESULT(32.0, -std::numeric_limits<double>::infinity(), double);
74+
75+
CHECK_TANH_RESULT(std::numeric_limits<float>::max(), 0.0, float);
76+
CHECK_TANH_RESULT(std::numeric_limits<double>::max(), 0.0, double);
77+
78+
CHECK_TANH_RESULT(0.0, std::numeric_limits<float>::max(), float);
79+
CHECK_TANH_RESULT(0.0, std::numeric_limits<double>::max(), double);
80+
81+
CHECK_TANH_RESULT(0.0, 0.0, float);
82+
CHECK_TANH_RESULT(0.0, 0.0, double);
83+
84+
CHECK_TANH_RESULT(std::numeric_limits<float>::quiet_NaN(), 1.0, float);
85+
CHECK_TANH_RESULT(std::numeric_limits<double>::quiet_NaN(), 1.0, double);
86+
87+
CHECK_TANH_RESULT(std::numeric_limits<float>::quiet_NaN(),
88+
std::numeric_limits<float>::quiet_NaN(), float);
89+
CHECK_TANH_RESULT(std::numeric_limits<double>::quiet_NaN(),
90+
std::numeric_limits<double>::quiet_NaN(), double);
91+
92+
// The MSVC implementation of tanh for complex numbers does not adhere to the
93+
// following requirements set by the definition of std::tanh:
94+
// * When the input has an infinite real, then the function should return
95+
// (1, +-0).
96+
// * When the input is (NaN, 0), the result should be (NaN, 0).
97+
// Instead we check the results using reference values rather than trusting
98+
// the result of std::tanh in these cases.
99+
CHECK_TANH_REF_RESULT(std::numeric_limits<float>::infinity(), 32.0, 1.0, 0.0,
100+
float);
101+
CHECK_TANH_REF_RESULT(std::numeric_limits<double>::infinity(), 32.0, 1.0, 0.0,
102+
double);
103+
104+
CHECK_TANH_REF_RESULT(std::numeric_limits<float>::infinity(),
105+
std::numeric_limits<float>::infinity(), 1, 0.0, float);
106+
CHECK_TANH_REF_RESULT(std::numeric_limits<double>::infinity(),
107+
std::numeric_limits<double>::infinity(), 1, 0.0,
108+
double);
109+
110+
CHECK_TANH_REF_RESULT(-std::numeric_limits<float>::infinity(), 32.0, -1.0,
111+
0.0, float);
112+
CHECK_TANH_REF_RESULT(-std::numeric_limits<double>::infinity(), 32.0, -1.0,
113+
0.0, double);
114+
115+
CHECK_TANH_REF_RESULT(-std::numeric_limits<float>::infinity(),
116+
-std::numeric_limits<float>::infinity(), -1.0, 0.0,
117+
float);
118+
CHECK_TANH_REF_RESULT(-std::numeric_limits<double>::infinity(),
119+
-std::numeric_limits<double>::infinity(), -1.0, 0.0,
120+
double);
121+
122+
CHECK_TANH_REF_RESULT(std::numeric_limits<float>::infinity(),
123+
std::numeric_limits<float>::quiet_NaN(), 1.0, 0.0,
124+
float);
125+
CHECK_TANH_REF_RESULT(std::numeric_limits<double>::infinity(),
126+
std::numeric_limits<double>::quiet_NaN(), 1.0, 0.0,
127+
double);
128+
129+
CHECK_TANH_REF_RESULT(-std::numeric_limits<float>::infinity(),
130+
std::numeric_limits<float>::quiet_NaN(), -1.0, 0.0,
131+
float);
132+
CHECK_TANH_REF_RESULT(-std::numeric_limits<double>::infinity(),
133+
std::numeric_limits<double>::quiet_NaN(), -1.0, 0.0,
134+
double);
135+
136+
CHECK_TANH_REF_RESULT(std::numeric_limits<float>::quiet_NaN(), 0.0,
137+
std::numeric_limits<float>::quiet_NaN(), 0.0, float);
138+
CHECK_TANH_REF_RESULT(std::numeric_limits<double>::quiet_NaN(), 0.0,
139+
std::numeric_limits<float>::quiet_NaN(), 0.0, double);
140+
141+
return Failures;
142+
}

0 commit comments

Comments
 (0)