2525
2626#pragma once
2727
28+ #define SYCL_EXT_ONEAPI_COMPLEX
29+ #if __has_include(<sycl/ext/oneapi/experimental/sycl_complex.hpp>)
30+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
31+ #else
32+ #include < sycl/ext/oneapi/experimental/complex/complex.hpp>
33+ #endif
34+
2835#include < complex>
2936#include < cstddef>
3037#include < vector>
3946
4047namespace dpnp ::kernels::isclose
4148{
49+ namespace exprm_ns = sycl::ext::oneapi::experimental;
4250
4351template <typename T>
4452inline bool isclose (const T a,
@@ -47,14 +55,39 @@ inline bool isclose(const T a,
4755 const T atol,
4856 const bool equal_nan)
4957{
50- if (sycl::isnan (a) || sycl::isnan (b)) {
51- // static cast<T>?
52- return equal_nan && sycl::isnan (a) && sycl::isnan (b);
58+ if (sycl::isfinite (a) && sycl::isfinite (b)) {
59+ return sycl::fabs (a - b) <= atol + rtol * sycl::fabs (b);
60+ }
61+
62+ if (sycl::isnan (a) && sycl::isnan (b)) {
63+ return equal_nan;
64+ }
65+
66+ return a == b;
67+ }
68+
69+ template <typename T>
70+ inline bool isclose (const std::complex <T> a,
71+ const std::complex <T> b,
72+ const T rtol,
73+ const T atol,
74+ const bool equal_nan)
75+ {
76+ const bool a_finite = sycl::isfinite (a.real ()) && sycl::isfinite (a.imag ());
77+ const bool b_finite = sycl::isfinite (b.real ()) && sycl::isfinite (b.imag ());
78+
79+ if (a_finite && b_finite) {
80+ return exprm_ns::abs (exprm_ns::complex <T>(a - b)) <=
81+ atol + rtol * exprm_ns::abs (exprm_ns::complex <T>(b));
5382 }
54- if (sycl::isinf (a) || sycl::isinf (b)) {
55- return a == b;
83+
84+ if (sycl::isnan (a.real ()) && sycl::isnan (a.imag ()) &&
85+ sycl::isnan (b.real ()) && sycl::isnan (b.imag ()))
86+ {
87+ return equal_nan;
5688 }
57- return sycl::fabs (a - b) <= atol + rtol * sycl::fabs (b);
89+
90+ return a == b;
5891}
5992
6093template <typename T,
@@ -95,18 +128,8 @@ struct IsCloseStridedScalarFunctor
95128 const dpctl::tensor::ssize_t &out_offset =
96129 three_offsets_.get_third_offset ();
97130
98- using dpctl::tensor::type_utils::is_complex_v;
99- if constexpr (is_complex_v<T>) {
100- T z_a = a_[inp1_offset];
101- T z_b = b_[inp2_offset];
102- bool x = isclose (z_a.real (), z_b.real (), rtol_, atol_, equal_nan_);
103- bool y = isclose (z_a.imag (), z_b.imag (), rtol_, atol_, equal_nan_);
104- out_[out_offset] = x && y;
105- }
106- else {
107- out_[out_offset] = isclose (a_[inp1_offset], b_[inp2_offset], rtol_,
108- atol_, equal_nan_);
109- }
131+ out_[out_offset] =
132+ isclose (a_[inp1_offset], b_[inp2_offset], rtol_, atol_, equal_nan_);
110133 }
111134};
112135
@@ -201,19 +224,8 @@ struct IsCloseContigScalarFunctor
201224 (gid / sgSize) * (elems_per_sg - sgSize) + gid;
202225 const std::size_t end = std::min (nelems_, start + elems_per_sg);
203226 for (std::size_t offset = start; offset < end; offset += sgSize) {
204- if constexpr (is_complex_v<T>) {
205- T z_a = a_[offset];
206- T z_b = b_[offset];
207- bool x = isclose (z_a.real (), z_b.real (), rtol_, atol_,
208- equal_nan_);
209- bool y = isclose (z_a.imag (), z_b.imag (), rtol_, atol_,
210- equal_nan_);
211- out_[offset] = x && y;
212- }
213- else {
214- out_[offset] = isclose (a_[offset], b_[offset], rtol_, atol_,
215- equal_nan_);
216- }
227+ out_[offset] =
228+ isclose (a_[offset], b_[offset], rtol_, atol_, equal_nan_);
217229 }
218230 }
219231 }
0 commit comments