Skip to content

Commit 9949d12

Browse files
committed
[SYCL] Fix complex tanh
The current implementation of the tanh function for the SYCL complex extension did not always return values precise enough, neither on device nor on host. This was caused by accumulated error from a call to cos. This commit changes the implementation of tanh to use a different way of calculating the result and tests it against std::tanh for complex numbers. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent d67098c commit 9949d12

File tree

2 files changed

+111
-13
lines changed

2 files changed

+111
-13
lines changed

sycl/include/sycl/ext/oneapi/experimental/complex/detail/complex_math.hpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -493,22 +493,20 @@ template <class _Tp>
493493
__DPCPP_SYCL_EXTERNAL _SYCL_EXT_CPLX_INLINE_VISIBILITY
494494
typename std::enable_if_t<is_genfloat<_Tp>::value, complex<_Tp>>
495495
tanh(const complex<_Tp> &__x) {
496-
if (sycl::isinf(__x.real())) {
497-
if (!sycl::isfinite(__x.imag()))
498-
return complex<_Tp>(sycl::copysign(_Tp(1), __x.real()), _Tp(0));
496+
if (sycl::isinf(__x.real()))
499497
return complex<_Tp>(sycl::copysign(_Tp(1), __x.real()),
500-
sycl::copysign(_Tp(0), sycl::sin(_Tp(2) * __x.imag())));
501-
}
498+
sycl::copysign(_Tp(0), sycl::isfinite(__x.imag())
499+
? sin(_Tp(2) * __x.imag())
500+
: _Tp(1)));
502501
if (sycl::isnan(__x.real()) && __x.imag() == 0)
503502
return __x;
504-
_Tp __2r(_Tp(2) * __x.real());
505-
_Tp __2i(_Tp(2) * __x.imag());
506-
_Tp __d(sycl::cosh(__2r) + sycl::cos(__2i));
507-
_Tp __2rsh(sycl::sinh(__2r));
508-
if (sycl::isinf(__2rsh) && sycl::isinf(__d))
509-
return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1),
510-
__2i > _Tp(0) ? _Tp(0) : _Tp(-0.));
511-
return complex<_Tp>(__2rsh / __d, sycl::sin(__2i) / __d);
503+
complex<_Tp> sinh_x = sinh(__x);
504+
complex<_Tp> cosh_x = cosh(__x);
505+
if (sycl::isinf(sinh_x.real()) && sycl::isinf(cosh_x.real()))
506+
return complex<_Tp>(sinh_x.real() * cosh_x.real() > _Tp(0) ? _Tp(1)
507+
: _Tp(-1),
508+
__x.imag() > _Tp(0) ? _Tp(0) : _Tp(-0.));
509+
return sinh_x / cosh_x;
512510
}
513511

514512
// asin
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
4+
// Checks the results of tanh on certain complex numbers.
5+
6+
#define SYCL_EXT_ONEAPI_COMPLEX
7+
8+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
9+
#include <sycl/sycl.hpp>
10+
11+
#include <complex>
12+
#include <limits>
13+
14+
namespace syclexp = sycl::ext::oneapi::experimental;
15+
16+
int Failures = 0;
17+
18+
template <typename T> bool FloatingPointEq(T LHS, T RHS) {
19+
if (std::isnan(LHS))
20+
return std::isnan(RHS);
21+
return LHS == RHS;
22+
}
23+
24+
#define CHECK_TANH_RESULT(REAL, IMAG, T) \
25+
{ \
26+
syclexp::complex<T> sycl_res = \
27+
syclexp::tanh(syclexp::complex<T>{REAL, IMAG}); \
28+
std::complex<T> std_res = std::tanh(std::complex<T>{REAL, IMAG}); \
29+
if (!FloatingPointEq(sycl_res.real(), std_res.real())) { \
30+
std::cout << "Real differ in tanh((" << REAL << ", " << IMAG \
31+
<< ")): " << sycl_res.real() << " != " << std_res.real() \
32+
<< std::endl; \
33+
++Failures; \
34+
} \
35+
if (!FloatingPointEq(sycl_res.imag(), std_res.imag())) { \
36+
std::cout << "Imag differ in tanh((" << REAL << ", " << IMAG \
37+
<< ")): " << sycl_res.imag() << " != " << std_res.imag() \
38+
<< std::endl; \
39+
++Failures; \
40+
} \
41+
}
42+
43+
int main() {
44+
CHECK_TANH_RESULT(0, -11.0, float);
45+
CHECK_TANH_RESULT(0, -11.0, double);
46+
47+
CHECK_TANH_RESULT(std::numeric_limits<float>::infinity(), 32.0, float);
48+
CHECK_TANH_RESULT(std::numeric_limits<double>::infinity(), 32.0, double);
49+
50+
CHECK_TANH_RESULT(32.0, std::numeric_limits<float>::infinity(), float);
51+
CHECK_TANH_RESULT(32.0, std::numeric_limits<double>::infinity(), double);
52+
53+
CHECK_TANH_RESULT(std::numeric_limits<float>::infinity(),
54+
std::numeric_limits<float>::infinity(), float);
55+
CHECK_TANH_RESULT(std::numeric_limits<double>::infinity(),
56+
std::numeric_limits<double>::infinity(), double);
57+
58+
CHECK_TANH_RESULT(-std::numeric_limits<float>::infinity(), 32.0, float);
59+
CHECK_TANH_RESULT(-std::numeric_limits<double>::infinity(), 32.0, double);
60+
61+
CHECK_TANH_RESULT(32.0, -std::numeric_limits<float>::infinity(), float);
62+
CHECK_TANH_RESULT(32.0, -std::numeric_limits<double>::infinity(), double);
63+
64+
CHECK_TANH_RESULT(-std::numeric_limits<float>::infinity(),
65+
-std::numeric_limits<float>::infinity(), float);
66+
CHECK_TANH_RESULT(-std::numeric_limits<double>::infinity(),
67+
-std::numeric_limits<double>::infinity(), double);
68+
69+
CHECK_TANH_RESULT(std::numeric_limits<float>::max(), 0.0, float);
70+
CHECK_TANH_RESULT(std::numeric_limits<double>::max(), 0.0, double);
71+
72+
CHECK_TANH_RESULT(0.0, std::numeric_limits<float>::max(), float);
73+
CHECK_TANH_RESULT(0.0, std::numeric_limits<double>::max(), double);
74+
75+
CHECK_TANH_RESULT(0.0, 0.0, float);
76+
CHECK_TANH_RESULT(0.0, 0.0, double);
77+
78+
CHECK_TANH_RESULT(std::numeric_limits<float>::infinity(),
79+
std::numeric_limits<float>::quiet_NaN(), float);
80+
CHECK_TANH_RESULT(std::numeric_limits<double>::infinity(),
81+
std::numeric_limits<double>::quiet_NaN(), double);
82+
83+
CHECK_TANH_RESULT(-std::numeric_limits<float>::infinity(),
84+
std::numeric_limits<float>::quiet_NaN(), float);
85+
CHECK_TANH_RESULT(-std::numeric_limits<double>::infinity(),
86+
std::numeric_limits<double>::quiet_NaN(), double);
87+
88+
CHECK_TANH_RESULT(std::numeric_limits<float>::quiet_NaN(), 0.0, float);
89+
CHECK_TANH_RESULT(std::numeric_limits<double>::quiet_NaN(), 0.0, double);
90+
91+
CHECK_TANH_RESULT(std::numeric_limits<float>::quiet_NaN(), 1.0, float);
92+
CHECK_TANH_RESULT(std::numeric_limits<double>::quiet_NaN(), 1.0, double);
93+
94+
CHECK_TANH_RESULT(std::numeric_limits<float>::quiet_NaN(),
95+
std::numeric_limits<float>::quiet_NaN(), float);
96+
CHECK_TANH_RESULT(std::numeric_limits<double>::quiet_NaN(),
97+
std::numeric_limits<double>::quiet_NaN(), double);
98+
99+
return Failures;
100+
}

0 commit comments

Comments
 (0)