Skip to content

Commit c5da321

Browse files
Update TestIsClose
1 parent 665f043 commit c5da321

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

dpnp/tests/test_logic.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import dpnp
1111

1212
from .helper import (
13+
generate_random_numpy_array,
1314
get_all_dtypes,
15+
get_complex_dtypes,
1416
get_float_complex_dtypes,
1517
get_float_dtypes,
1618
get_integer_float_dtypes,
@@ -554,6 +556,21 @@ def test_isclose(self, dtype, rtol, atol):
554556
dpnp_res = dpnp.isclose(dpnp_a, dpnp_b, rtol=rtol, atol=atol)
555557
assert_allclose(dpnp_res, np_res)
556558

559+
@pytest.mark.parametrize("dtype", get_complex_dtypes())
560+
@pytest.mark.parametrize("shape", [(4, 4), (16, 16), (4, 4, 4)])
561+
def test_isclose_complex(self, dtype, shape):
562+
a = generate_random_numpy_array(shape, dtype=dtype, seed_value=81)
563+
b = a.copy()
564+
565+
b = b + (1e-6 + 1e-6j)
566+
567+
dpnp_a = dpnp.array(a, dtype=dtype)
568+
dpnp_b = dpnp.array(b, dtype=dtype)
569+
570+
np_res = numpy.isclose(a, b)
571+
dpnp_res = dpnp.isclose(dpnp_a, dpnp_b)
572+
assert_allclose(dpnp_res, np_res)
573+
557574
@pytest.mark.parametrize(
558575
"sh_a, sh_b",
559576
[
@@ -603,14 +620,14 @@ def test_rtol_atol_arrays(self):
603620
@pytest.mark.parametrize(
604621
"rtol, atol",
605622
[
606-
(1e-05 + 1j, 1e-08),
607-
(1e-05, 1e-08 + 1j),
608-
(1e-05 + 1j, 1e-08 + 1j),
623+
(0 + 1e-5j, 1e-08),
624+
(1e-05, 0 + 1e-8j),
625+
(0 + 1e-5j, 0 + 1e-8j),
609626
],
610627
)
611628
def test_rtol_atol_complex(self, rtol, atol):
612-
a = dpnp.array([1.0, 2.0])
613-
b = dpnp.array([1.0, 2.0 + 1e-7])
629+
a = dpnp.array([1.0, 1.0])
630+
b = dpnp.array([1.0, 1.0 + 1e-6])
614631

615632
dpnp_res = dpnp.isclose(a, b, rtol=rtol, atol=atol)
616633
np_res = numpy.isclose(a.asnumpy(), b.asnumpy(), rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)