Skip to content

Commit fb66703

Browse files
Apply remarks
1 parent 584239a commit fb66703

File tree

2 files changed

+31
-22
lines changed

2 files changed

+31
-22
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,18 @@ def _isclose_scalar_tol(a, b, rtol, atol, equal_nan):
115115
a = dpnp.astype(a, dt, casting="same_kind", copy=False)
116116
b = dpnp.astype(b, dt, casting="same_kind", copy=False)
117117

118-
# Convert complex rtol/atol to to their real parts
118+
# Convert complex rtol/atol to their real parts
119119
# to avoid pybind11 cast errors and match NumPy behavior
120120
if isinstance(rtol, complex):
121121
rtol = rtol.real
122122
if isinstance(atol, complex):
123123
atol = atol.real
124124

125+
# Convert equal_nan to bool to avoid pybind11 cast errors
126+
# and match NumPy behavior
127+
if not isinstance(equal_nan, bool):
128+
equal_nan = bool(equal_nan)
129+
125130
# pylint: disable=W0707
126131
try:
127132
a, b = dpnp.broadcast_arrays(a, b)
@@ -131,9 +136,8 @@ def _isclose_scalar_tol(a, b, rtol, atol, equal_nan):
131136
f"{a.shape} and {b.shape}"
132137
)
133138

134-
out_dtype = dpnp.bool
135139
output = dpnp.empty(
136-
a.shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type
140+
a.shape, dtype=dpnp.bool, sycl_queue=exec_q, usm_type=usm_type
137141
)
138142

139143
_manager = dpu.SequentialOrderManager[exec_q]
@@ -871,7 +875,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
871875
The absolute tolerance parameter.
872876
873877
Default: ``1e-08``.
874-
equal_nan : bool
878+
equal_nan : bool, optional
875879
Whether to compare ``NaNs`` as equal. If ``True``, ``NaNs`` in `a` will
876880
be considered equal to ``NaNs`` in `b` in the output array.
877881

dpnp/tests/test_logic.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -519,38 +519,40 @@ def test_infinity_sign_errors(func):
519519

520520

521521
class TestIsClose:
522-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
522+
@pytest.mark.parametrize(
523+
"dtype", get_all_dtypes(no_bool=True, no_none=True)
524+
)
523525
@pytest.mark.parametrize(
524526
"rtol", [1e-5, dpnp.array(1e-5), dpnp.full((10,), 1e-5)]
525527
)
526528
@pytest.mark.parametrize(
527529
"atol", [1e-8, dpnp.array(1e-8), dpnp.full((10,), 1e-8)]
528530
)
529531
def test_isclose(self, dtype, rtol, atol):
530-
a = numpy.random.rand(10)
531-
b = a + numpy.random.rand(10) * 1e-8
532+
a = generate_random_numpy_array((10,), dtype=dtype)
533+
b = a + numpy.array(1e-8, dtype=dtype)
532534

533535
dpnp_a = dpnp.array(a, dtype=dtype)
534536
dpnp_b = dpnp.array(b, dtype=dtype)
535537

536538
np_res = numpy.isclose(a, b, rtol=1e-5, atol=1e-8)
537539
dpnp_res = dpnp.isclose(dpnp_a, dpnp_b, rtol=rtol, atol=atol)
538-
assert_allclose(dpnp_res, np_res)
540+
assert_equal(dpnp_res, np_res)
539541

540542
@pytest.mark.parametrize("dtype", get_complex_dtypes())
541543
@pytest.mark.parametrize("shape", [(4, 4), (16, 16), (4, 4, 4)])
542544
def test_isclose_complex(self, dtype, shape):
543545
a = generate_random_numpy_array(shape, dtype=dtype, seed_value=81)
544546
b = a.copy()
545547

546-
b = b + (1e-6 + 1e-6j)
548+
b = b + numpy.array(1e-6 + 1e-6j, dtype=dtype)
547549

548550
dpnp_a = dpnp.array(a, dtype=dtype)
549551
dpnp_b = dpnp.array(b, dtype=dtype)
550552

551553
np_res = numpy.isclose(a, b)
552554
dpnp_res = dpnp.isclose(dpnp_a, dpnp_b)
553-
assert_allclose(dpnp_res, np_res)
555+
assert_equal(dpnp_res, np_res)
554556

555557
@pytest.mark.parametrize(
556558
"rtol, atol",
@@ -568,7 +570,7 @@ def test_empty_input(self, rtol, atol):
568570

569571
np_res = numpy.isclose(a, b, rtol=1e-5, atol=1e-8)
570572
dpnp_res = dpnp.isclose(dpnp_a, dpnp_b, rtol=rtol, atol=atol)
571-
assert_allclose(dpnp_res, np_res)
573+
assert_equal(dpnp_res, np_res)
572574

573575
@pytest.mark.parametrize(
574576
"rtol, atol",
@@ -585,17 +587,17 @@ def test_input_0d(self, val, rtol, atol):
585587
# array & scalar
586588
dp_res = dpnp.isclose(dp_arr, val, rtol=rtol, atol=atol)
587589
np_res = numpy.isclose(np_arr, val, rtol=1e-5, atol=1e-8)
588-
assert_allclose(dp_res, np_res)
590+
assert_equal(dp_res, np_res)
589591

590592
# scalar & array
591593
dp_res = dpnp.isclose(val, dp_arr, rtol=rtol, atol=atol)
592594
np_res = numpy.isclose(val, np_arr, rtol=1e-5, atol=1e-8)
593-
assert_allclose(dp_res, np_res)
595+
assert_equal(dp_res, np_res)
594596

595597
# array & array
596598
dp_res = dpnp.isclose(dp_arr, dp_arr, rtol=rtol, atol=atol)
597599
np_res = numpy.isclose(np_arr, np_arr, rtol=1e-5, atol=1e-8)
598-
assert_allclose(dp_res, np_res)
600+
assert_equal(dp_res, np_res)
599601

600602
@pytest.mark.parametrize(
601603
"sh_a, sh_b",
@@ -615,7 +617,7 @@ def test_broadcast_shapes(self, sh_a, sh_b):
615617

616618
np_res = numpy.isclose(a_np, b_np)
617619
dp_res = dpnp.isclose(a_dp, b_dp)
618-
assert_allclose(dp_res, np_res)
620+
assert_equal(dp_res, np_res)
619621

620622
@pytest.mark.parametrize(
621623
"rtol, atol",
@@ -624,16 +626,19 @@ def test_broadcast_shapes(self, sh_a, sh_b):
624626
(dpnp.array(1e-5), dpnp.array(1e-8)),
625627
],
626628
)
627-
def test_equal_nan(self, rtol, atol):
629+
@pytest.mark.parametrize("equal_nan", [True, 1, "1"])
630+
def test_equal_nan(self, rtol, atol, equal_nan):
628631
a = numpy.array([numpy.nan, 1.0])
629632
b = numpy.array([numpy.nan, 1.0])
630633

631634
dp_a = dpnp.array(a)
632635
dp_b = dpnp.array(b)
633636

634-
np_res = numpy.isclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=True)
635-
dp_res = dpnp.isclose(dp_a, dp_b, rtol=rtol, atol=atol, equal_nan=True)
636-
assert_allclose(dp_res, np_res)
637+
np_res = numpy.isclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=equal_nan)
638+
dp_res = dpnp.isclose(
639+
dp_a, dp_b, rtol=rtol, atol=atol, equal_nan=equal_nan
640+
)
641+
assert_equal(dp_res, np_res)
637642

638643
# array-like rtol/atol support requires NumPy >= 2.0
639644
@testing.with_requires("numpy>=2.0")
@@ -650,7 +655,7 @@ def test_rtol_atol_arrays(self):
650655

651656
np_res = numpy.isclose(a, b, rtol=rtol, atol=atol)
652657
dp_res = dpnp.isclose(dp_a, dp_b, rtol=dp_rtol, atol=dp_atol)
653-
assert_allclose(dp_res, np_res)
658+
assert_equal(dp_res, np_res)
654659

655660
@pytest.mark.parametrize(
656661
"rtol, atol",
@@ -666,7 +671,7 @@ def test_rtol_atol_complex(self, rtol, atol):
666671

667672
dpnp_res = dpnp.isclose(a, b, rtol=rtol, atol=atol)
668673
np_res = numpy.isclose(a.asnumpy(), b.asnumpy(), rtol=rtol, atol=atol)
669-
assert_allclose(dpnp_res, np_res)
674+
assert_equal(dpnp_res, np_res)
670675

671676
# NEP 50: float32 vs Python float comparison requires NumPy >= 2.0
672677
@testing.with_requires("numpy>=2.0")
@@ -675,7 +680,7 @@ def test_rtol_atol_nep50(self):
675680
f32 = numpy.array(below_one, dtype="f4")
676681
dp_f32 = dpnp.array(f32)
677682

678-
assert_allclose(
683+
assert_equal(
679684
dpnp.isclose(dp_f32, below_one, rtol=0, atol=0),
680685
numpy.isclose(f32, below_one, rtol=0, atol=0),
681686
)

0 commit comments

Comments
 (0)