Skip to content

Commit 7a8ae1a

Browse files
committed
[SYCL] Fix complex tanh
The current implementation of the tanh function for the SYCL complex extension did not always return values precise enough, neither on device nor on host. This was caused by accumulated error from a call to cos. This commit changes the implementation of tanh to use a different way of calculating the result and tests it against std::tanh for complex numbers. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 446a453 commit 7a8ae1a

File tree

2 files changed

+93
-11
lines changed

2 files changed

+93
-11
lines changed

sycl/include/sycl/ext/oneapi/experimental/complex/detail/complex_math.hpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -495,20 +495,18 @@ __DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY
495495
tanh(const complex<_Tp> &__x) {
496496
if (sycl::isinf(__x.real())) {
497497
if (!sycl::isfinite(__x.imag()))
498-
return complex<_Tp>(sycl::copysign(_Tp(1), __x.real()), _Tp(0));
499-
return complex<_Tp>(sycl::copysign(_Tp(1), __x.real()),
500-
sycl::copysign(_Tp(0), sycl::sin(_Tp(2) * __x.imag())));
498+
return complex<_Tp>(_Tp(1), _Tp(0));
499+
return complex<_Tp>(_Tp(1), copysign(_Tp(0), sin(_Tp(2) * __x.imag())));
501500
}
502501
if (sycl::isnan(__x.real()) && __x.imag() == 0)
503502
return __x;
504-
_Tp __2r(_Tp(2) * __x.real());
505-
_Tp __2i(_Tp(2) * __x.imag());
506-
_Tp __d(sycl::cosh(__2r) + sycl::cos(__2i));
507-
_Tp __2rsh(sycl::sinh(__2r));
508-
if (sycl::isinf(__2rsh) && sycl::isinf(__d))
509-
return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1),
510-
__2i > _Tp(0) ? _Tp(0) : _Tp(-0.));
511-
return complex<_Tp>(__2rsh / __d, sycl::sin(__2i) / __d);
503+
complex<_Tp> sinh_x = sinh(__x);
504+
complex<_Tp> cosh_x = cosh(__x);
505+
if (sycl::isinf(sinh_x.real()) && sycl::isinf(cosh_x.real()))
506+
return complex<_Tp>(sinh_x.real() * cosh_x.real() > _Tp(0) ? _Tp(1)
507+
: _Tp(-1),
508+
_Tp(2) * __x.imag() > _Tp(0) ? _Tp(0) : _Tp(-0.));
509+
return sinh_x / cosh_x;
512510
}
513511

514512
// asin
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
4+
// Checks the results of tanh on certain complex numbers.
5+
6+
#define SYCL_EXT_ONEAPI_COMPLEX
7+
8+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
9+
#include <sycl/sycl.hpp>
10+
11+
#include <complex>
12+
#include <limits>
13+
14+
namespace syclexp = sycl::ext::oneapi::experimental;
15+
16+
int Failures = 0;
17+
18+
template <typename T> bool FloatingPointEq(T LHS, T RHS) {
19+
if (std::isnan(LHS))
20+
return std::isnan(RHS);
21+
return LHS == RHS;
22+
}
23+
24+
#define CHECK_TANH_RESULT(REAL, IMAG, T) \
25+
{ \
26+
syclexp::complex<T> sycl_res = \
27+
syclexp::tanh(syclexp::complex<T>{REAL, IMAG}); \
28+
std::complex<T> std_res = std::tanh(std::complex<T>{REAL, IMAG}); \
29+
if (!FloatingPointEq(sycl_res.real(), std_res.real())) { \
30+
std::cout << "Real differ in tanh((" << REAL << ", " << IMAG \
31+
<< ")): " << sycl_res.real() << " != " << std_res.real() \
32+
<< std::endl; \
33+
++Failures; \
34+
} \
35+
if (!FloatingPointEq(sycl_res.imag(), std_res.imag())) { \
36+
std::cout << "Imag differ in tanh((" << REAL << ", " << IMAG \
37+
<< ")): " << sycl_res.imag() << " != " << std_res.imag() \
38+
<< std::endl; \
39+
++Failures; \
40+
} \
41+
}
42+
43+
int main() {
44+
CHECK_TANH_RESULT(0, -11.0, float);
45+
CHECK_TANH_RESULT(0, -11.0, double);
46+
47+
CHECK_TANH_RESULT(std::numeric_limits<float>::infinity(), 32.0, float);
48+
CHECK_TANH_RESULT(std::numeric_limits<double>::infinity(), 32.0, double);
49+
50+
CHECK_TANH_RESULT(32.0, std::numeric_limits<float>::infinity(), float);
51+
CHECK_TANH_RESULT(32.0, std::numeric_limits<double>::infinity(), double);
52+
53+
CHECK_TANH_RESULT(std::numeric_limits<float>::infinity(),
54+
std::numeric_limits<float>::infinity(), float);
55+
CHECK_TANH_RESULT(std::numeric_limits<double>::infinity(),
56+
std::numeric_limits<double>::infinity(), double);
57+
58+
CHECK_TANH_RESULT(std::numeric_limits<float>::max(), 0.0, float);
59+
CHECK_TANH_RESULT(std::numeric_limits<double>::max(), 0.0, double);
60+
61+
CHECK_TANH_RESULT(0.0, std::numeric_limits<float>::max(), float);
62+
CHECK_TANH_RESULT(0.0, std::numeric_limits<double>::max(), double);
63+
64+
CHECK_TANH_RESULT(0.0, 0.0, float);
65+
CHECK_TANH_RESULT(0.0, 0.0, double);
66+
67+
CHECK_TANH_RESULT(std::numeric_limits<float>::infinity(),
68+
std::numeric_limits<float>::quiet_NaN(), float);
69+
CHECK_TANH_RESULT(std::numeric_limits<double>::infinity(),
70+
std::numeric_limits<float>::quiet_NaN(), double);
71+
72+
CHECK_TANH_RESULT(std::numeric_limits<float>::quiet_NaN(), 0.0, float);
73+
CHECK_TANH_RESULT(std::numeric_limits<double>::quiet_NaN(), 0.0, double);
74+
75+
CHECK_TANH_RESULT(std::numeric_limits<float>::quiet_NaN(), 1.0, float);
76+
CHECK_TANH_RESULT(std::numeric_limits<double>::quiet_NaN(), 1.0, double);
77+
78+
CHECK_TANH_RESULT(std::numeric_limits<float>::quiet_NaN(),
79+
std::numeric_limits<float>::quiet_NaN(), float);
80+
CHECK_TANH_RESULT(std::numeric_limits<double>::quiet_NaN(),
81+
std::numeric_limits<double>::quiet_NaN(), double);
82+
83+
return Failures;
84+
}

0 commit comments

Comments
 (0)