Skip to content

Commit 970dd33

Browse files
committed
isclose not inf vs. inf
1 parent 23fe21f commit 970dd33

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,9 +566,24 @@ def isclose(
566566
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
567567
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
568568
if a_inexact or b_inexact:
569-
# FIXME: use scipy's lazywhere to suppress warnings on inf
570-
out = xp.abs(a - b) <= (atol + rtol * xp.abs(b))
571-
out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out)
569+
# prevent warnings on numpy and dask on inf - inf
570+
meta_xp = meta_namespace(a, b, xp=xp)
571+
572+
def where_inf(a: Array, b: Array) -> Array:
573+
return (
574+
meta_xp.isinf(a)
575+
& meta_xp.isinf(b)
576+
& (meta_xp.sign(a) == meta_xp.sign(b))
577+
)
578+
579+
def where_not_inf(a: Array, b: Array) -> Array:
580+
# Note: inf <= inf is True!
581+
return meta_xp.abs(a - b) <= (atol + rtol * meta_xp.abs(b))
582+
583+
out = apply_where(
584+
xp.isinf(a) | xp.isinf(b), where_inf, where_not_inf, a, b, xp=xp
585+
)
586+
572587
if equal_nan:
573588
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
574589
return out

tests/test_funcs.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,7 @@ def test_xp(self, xp: ModuleType):
334334

335335
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
336336
class TestIsClose:
337-
# FIXME use lazywhere to avoid warnings on inf
338-
@pytest.mark.filterwarnings("ignore:invalid value encountered")
337+
@pytest.mark.parametrize("swap", [False, True])
339338
@pytest.mark.parametrize(
340339
("a", "b"),
341340
[
@@ -353,9 +352,9 @@ class TestIsClose:
353352
(float("inf"), float("inf")),
354353
(float("inf"), 100.0),
355354
(float("inf"), float("-inf")),
355+
(float("-inf"), float("-inf")),
356356
(float("nan"), float("nan")),
357-
(float("nan"), 0.0),
358-
(0.0, float("nan")),
357+
(float("nan"), 100.0),
359358
(1e6, 1e6 + 1), # True - within rtol
360359
(1e6, 1e6 + 100), # False - outside rtol
361360
(1e-6, 1.1e-6), # False - outside atol
@@ -364,19 +363,20 @@ class TestIsClose:
364363
(1e6 + 0j, 1e6 + 100j), # False - outside rtol
365364
],
366365
)
367-
def test_basic(self, a: float, b: float, xp: ModuleType):
366+
def test_basic(self, a: float, b: float, swap: bool, xp: ModuleType):
367+
if swap:
368+
b, a = a, b
368369
a_xp = xp.asarray(a)
369370
b_xp = xp.asarray(b)
370371

371372
xp_assert_equal(isclose(a_xp, b_xp), xp.asarray(np.isclose(a, b)))
372373

373374
with warnings.catch_warnings():
374375
warnings.simplefilter("ignore")
375-
r_xp = xp.asarray(np.arange(10), dtype=a_xp.dtype)
376-
ar_xp = a_xp * r_xp
377-
br_xp = b_xp * r_xp
378376
ar_np = a * np.arange(10)
379377
br_np = b * np.arange(10)
378+
ar_xp = xp.asarray(ar_np)
379+
br_xp = xp.asarray(br_np)
380380

381381
xp_assert_equal(isclose(ar_xp, br_xp), xp.asarray(np.isclose(ar_np, br_np)))
382382

@@ -392,17 +392,15 @@ def test_broadcast(self, dtype: str, xp: ModuleType):
392392

393393
xp_assert_equal(actual, expect)
394394

395-
# FIXME use lazywhere to avoid warnings on inf
396-
@pytest.mark.filterwarnings("ignore:invalid value encountered")
397395
def test_some_inf(self, xp: ModuleType):
398-
a = xp.asarray([0.0, 1.0, float("inf"), float("inf"), float("inf")])
399-
b = xp.asarray([1e-9, 1.0, float("inf"), float("-inf"), 2.0])
396+
a = xp.asarray([0.0, 1.0, xp.inf, xp.inf, xp.inf])
397+
b = xp.asarray([1e-9, 1.0, xp.inf, -xp.inf, 2.0])
400398
actual = isclose(a, b)
401399
xp_assert_equal(actual, xp.asarray([True, True, True, False, False]))
402400

403401
def test_equal_nan(self, xp: ModuleType):
404-
a = xp.asarray([float("nan"), float("nan"), 1.0])
405-
b = xp.asarray([float("nan"), 1.0, float("nan")])
402+
a = xp.asarray([xp.nan, xp.nan, 1.0])
403+
b = xp.asarray([xp.nan, 1.0, xp.nan])
406404
xp_assert_equal(isclose(a, b), xp.asarray([False, False, False]))
407405
xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))
408406

0 commit comments

Comments
 (0)