Skip to content

Commit 2d1dea0

Browse files
committed
[SYCL] Follow-up to exp(complex) update
Follow-up to intel#15672 This patch updates float behavior
1 parent ba04efc commit 2d1dea0

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

libdevice/fallback-complex-fp64.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ double __complex__ __devicelib_cexp(double __complex__ z) {
154154
return z;
155155
else /* z_imag != 0.0 */
156156
return CMPLX(NAN, NAN);
157-
} else if (__spirv_IsFinite(z_real)) {
158-
if (__spirv_IsNan(z_imag) || __spirv_IsInf(z_imag))
159-
return CMPLX(NAN, NAN);
157+
} else if (__spirv_IsFinite(z_real) &&
158+
(__spirv_IsNan(z_imag) || __spirv_IsInf(z_imag))) {
159+
return CMPLX(NAN, NAN);
160160
}
161161
double __e = __spirv_ocl_exp(z_real);
162162
double ret_real = __e * __spirv_ocl_cos(z_imag);

libdevice/fallback-complex.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,14 @@ float __complex__ __devicelib_cexpf(float __complex__ z) {
150150
z_imag = NAN;
151151
return CMPLXF(z_real, z_imag);
152152
}
153-
} else if (__spirv_IsNan(z_real) && (z_imag == 0.0f)) {
154-
return z;
153+
} else if (__spirv_IsNan(z_real)) {
154+
if (z_imag == 0.0f)
155+
return z;
156+
else /* z_imag != 0.0f */
157+
return CMPLX(NAN, NAN);
158+
} else if (__spirv_IsFinite(z_real) &&
159+
(__spirv_IsNan(z_imag) || __spirv_IsInf(z_imag))) {
160+
return CMPLX(NAN, NAN);
155161
}
156162
float __e = __spirv_ocl_exp(z_real);
157163
float ret_real = __e * __spirv_ocl_cos(z_imag);
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// This test checks edge cases handling for std::exp(std::complex<float>) used
2+
// in SYCL kernels.
3+
//
4+
// UNSUPPORTED: hip || cuda
5+
//
6+
// RUN: %{build} -o %t.out
7+
// RUN: %{run} %t.out
8+
9+
#include "exp-std-complex-edge-cases.hpp"
10+
11+
int main() { return test<float>(); }

0 commit comments

Comments
 (0)