Skip to content

Commit ddd2ec0

Browse files
Fix incorrect complex isclose logic and add proper overload
1 parent 68b3878 commit ddd2ec0

File tree

1 file changed

+43
-31
lines changed

1 file changed

+43
-31
lines changed

dpnp/backend/kernels/elementwise_functions/isclose.hpp

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525

2626
#pragma once
2727

28+
#define SYCL_EXT_ONEAPI_COMPLEX
29+
#if __has_include(<sycl/ext/oneapi/experimental/sycl_complex.hpp>)
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+
#else
32+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
33+
#endif
34+
2835
#include <complex>
2936
#include <cstddef>
3037
#include <vector>
@@ -39,6 +46,7 @@
3946

4047
namespace dpnp::kernels::isclose
4148
{
49+
namespace exprm_ns = sycl::ext::oneapi::experimental;
4250

4351
template <typename T>
4452
inline bool isclose(const T a,
@@ -47,14 +55,39 @@ inline bool isclose(const T a,
4755
const T atol,
4856
const bool equal_nan)
4957
{
50-
if (sycl::isnan(a) || sycl::isnan(b)) {
51-
// static cast<T>?
52-
return equal_nan && sycl::isnan(a) && sycl::isnan(b);
58+
if (sycl::isfinite(a) && sycl::isfinite(b)) {
59+
return sycl::fabs(a - b) <= atol + rtol * sycl::fabs(b);
60+
}
61+
62+
if (sycl::isnan(a) && sycl::isnan(b)) {
63+
return equal_nan;
64+
}
65+
66+
return a == b;
67+
}
68+
69+
template <typename T>
70+
inline bool isclose(const std::complex<T> a,
71+
const std::complex<T> b,
72+
const T rtol,
73+
const T atol,
74+
const bool equal_nan)
75+
{
76+
const bool a_finite = sycl::isfinite(a.real()) && sycl::isfinite(a.imag());
77+
const bool b_finite = sycl::isfinite(b.real()) && sycl::isfinite(b.imag());
78+
79+
if (a_finite && b_finite) {
80+
return exprm_ns::abs(exprm_ns::complex<T>(a - b)) <=
81+
atol + rtol * exprm_ns::abs(exprm_ns::complex<T>(b));
5382
}
54-
if (sycl::isinf(a) || sycl::isinf(b)) {
55-
return a == b;
83+
84+
if (sycl::isnan(a.real()) && sycl::isnan(a.imag()) &&
85+
sycl::isnan(b.real()) && sycl::isnan(b.imag()))
86+
{
87+
return equal_nan;
5688
}
57-
return sycl::fabs(a - b) <= atol + rtol * sycl::fabs(b);
89+
90+
return a == b;
5891
}
5992

6093
template <typename T,
@@ -95,18 +128,8 @@ struct IsCloseStridedScalarFunctor
95128
const dpctl::tensor::ssize_t &out_offset =
96129
three_offsets_.get_third_offset();
97130

98-
using dpctl::tensor::type_utils::is_complex_v;
99-
if constexpr (is_complex_v<T>) {
100-
T z_a = a_[inp1_offset];
101-
T z_b = b_[inp2_offset];
102-
bool x = isclose(z_a.real(), z_b.real(), rtol_, atol_, equal_nan_);
103-
bool y = isclose(z_a.imag(), z_b.imag(), rtol_, atol_, equal_nan_);
104-
out_[out_offset] = x && y;
105-
}
106-
else {
107-
out_[out_offset] = isclose(a_[inp1_offset], b_[inp2_offset], rtol_,
108-
atol_, equal_nan_);
109-
}
131+
out_[out_offset] =
132+
isclose(a_[inp1_offset], b_[inp2_offset], rtol_, atol_, equal_nan_);
110133
}
111134
};
112135

@@ -201,19 +224,8 @@ struct IsCloseContigScalarFunctor
201224
(gid / sgSize) * (elems_per_sg - sgSize) + gid;
202225
const std::size_t end = std::min(nelems_, start + elems_per_sg);
203226
for (std::size_t offset = start; offset < end; offset += sgSize) {
204-
if constexpr (is_complex_v<T>) {
205-
T z_a = a_[offset];
206-
T z_b = b_[offset];
207-
bool x = isclose(z_a.real(), z_b.real(), rtol_, atol_,
208-
equal_nan_);
209-
bool y = isclose(z_a.imag(), z_b.imag(), rtol_, atol_,
210-
equal_nan_);
211-
out_[offset] = x && y;
212-
}
213-
else {
214-
out_[offset] = isclose(a_[offset], b_[offset], rtol_, atol_,
215-
equal_nan_);
216-
}
227+
out_[offset] =
228+
isclose(a_[offset], b_[offset], rtol_, atol_, equal_nan_);
217229
}
218230
}
219231
}

0 commit comments

Comments
 (0)