Skip to content

Commit 919de5f

Browse files
committed
ENH: new functions isclose and allclose
1 parent 48fb66a commit 919de5f

File tree

7 files changed

+289
-4
lines changed

7 files changed

+289
-4
lines changed

docs/api-reference.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
:nosignatures:
77
:toctree: generated
88
9+
allclose
910
at
1011
atleast_nd
1112
cov
1213
create_diagonal
1314
expand_dims
15+
isclose
1416
kron
1517
nunique
1618
pad

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ checks = [
293293
"all", # report on all checks, except the below
294294
"EX01", # most docstrings do not need an example
295295
"SA01", # data-apis/array-api-extra#87
296+
"SA04", # Missing description for See Also cross-reference
296297
"ES01", # most docstrings do not need an extended summary
297298
]
298299
exclude = [ # don't report on objects that match any of these regex

src/array_api_extra/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import pad
3+
from ._delegation import allclose, isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
atleast_nd,
@@ -18,11 +18,13 @@
1818
# pylint: disable=duplicate-code
1919
__all__ = [
2020
"__version__",
21+
"allclose",
2122
"at",
2223
"atleast_nd",
2324
"cov",
2425
"create_diagonal",
2526
"expand_dims",
27+
"isclose",
2628
"kron",
2729
"nunique",
2830
"pad",

src/array_api_extra/_delegation.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ._lib._utils._compat import array_namespace
88
from ._lib._utils._typing import Array
99

10-
__all__ = ["pad"]
10+
__all__ = ["allclose", "isclose", "pad"]
1111

1212

1313
def _delegate(xp: ModuleType, *backends: Backend) -> bool:
@@ -29,6 +29,144 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:
2929
return any(backend.is_namespace(xp) for backend in backends)
3030

3131

32+
def allclose(
33+
a: Array,
34+
b: Array,
35+
*,
36+
rtol: float = 1e-05,
37+
atol: float = 1e-08,
38+
equal_nan: bool = False,
39+
xp: ModuleType | None = None,
40+
) -> Array:
41+
"""
42+
Return True if two arrays are element-wise equal within a tolerance.
43+
44+
This is a simple convenience reduction around `isclose`.
45+
46+
Parameters
47+
----------
48+
a, b : Array
49+
Input arrays to compare.
50+
rtol : array_like, optional
51+
The relative tolerance parameter (see Notes).
52+
atol : array_like, optional
53+
The absolute tolerance parameter (see Notes).
54+
equal_nan : bool, optional
55+
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
56+
equal to NaN's in `b` in the output array.
57+
xp : array_namespace, optional
58+
The standard-compatible namespace for `a` and `b`. Default: infer.
59+
60+
Returns
61+
-------
62+
Array
63+
A 0-dimensional boolean array, containing `True` if `a` is close to `b`, and
64+
`False` otherwise.
65+
66+
See Also
67+
--------
68+
isclose
69+
math.isclose
70+
71+
Notes
72+
-----
73+
If `xp` is a lazy backend (e.g. Dask, JAX), you may not be able to test the result
74+
contents with `bool(allclose(a, b))` or `if allclose(a, b): ...`.
75+
"""
76+
xp = array_namespace(a, b) if xp is None else xp
77+
return xp.all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp))
78+
79+
80+
def isclose(
81+
a: Array,
82+
b: Array,
83+
*,
84+
rtol: float = 1e-05,
85+
atol: float = 1e-08,
86+
equal_nan: bool = False,
87+
xp: ModuleType | None = None,
88+
) -> Array:
89+
"""
90+
Return a boolean array where two arrays are element-wise equal within a tolerance.
91+
92+
The tolerance values are positive, typically very small numbers. The relative
93+
difference (rtol * abs(b)) and the absolute difference atol are added together to
94+
compare against the absolute difference between a and b.
95+
96+
NaNs are treated as equal if they are in the same place and if equal_nan=True. Infs
97+
are treated as equal if they are in the same place and of the same sign in both
98+
arrays.
99+
100+
Parameters
101+
----------
102+
a, b : Array
103+
Input arrays to compare.
104+
rtol : array_like, optional
105+
The relative tolerance parameter (see Notes).
106+
atol : array_like, optional
107+
The absolute tolerance parameter (see Notes).
108+
equal_nan : bool, optional
109+
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
110+
equal to NaN's in `b` in the output array.
111+
xp : array_namespace, optional
112+
The standard-compatible namespace for `a` and `b`. Default: infer.
113+
114+
Returns
115+
-------
116+
Array
117+
A boolean array of shape broadcasted from `a` and `b`, containing `True` where
118+
``a`` is close to ``b``, and `False` otherwise.
119+
120+
Warnings
121+
--------
122+
The default atol is not appropriate for comparing numbers with magnitudes much
123+
smaller than one ) (see notes).
124+
125+
See Also
126+
--------
127+
allclose
128+
math.isclose
129+
130+
Notes
131+
-----
132+
For finite values, `isclose` uses the following equation to test whether two
133+
floating point values are equivalent::
134+
135+
absolute(a - b) <= (atol + rtol * absolute(b))
136+
137+
Unlike the built-in `math.isclose`, the above equation is not symmetric in a and b,
138+
so that `isclose(a, b)` might be different from `isclose(b, a)` in some rare
139+
cases.
140+
141+
The default value of `atol` is not appropriate when the reference value `b` has
142+
magnitude smaller than one. For example, it is unlikely that `a = 1e-9` and
143+
`b = 2e-9` should be considered "close", yet `isclose(1e-9, 2e-9)` is `True` with
144+
default settings. Be sure to select atol for the use case at hand, especially for
145+
defining the threshold below which a non-zero value in `a` will be considered
146+
"close" to a very small or zero value in `b`.
147+
148+
The comparison of `a` and `b` uses standard broadcasting, which means that `a` and
149+
`b` need not have the same shape in order for `isclose(a, b)` to evaluate to
150+
`True`.
151+
152+
`isclose` is not defined for non-numeric data types. `bool` is considered a numeric
153+
data-type for this purpose.
154+
"""
155+
xp = array_namespace(a, b) if xp is None else xp
156+
157+
if _delegate(
158+
xp,
159+
Backend.NUMPY,
160+
Backend.CUPY,
161+
Backend.DASK,
162+
Backend.JAX,
163+
Backend.TORCH,
164+
):
165+
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
166+
167+
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
168+
169+
32170
def pad(
33171
x: Array,
34172
pad_width: int | tuple[int, int] | list[tuple[int, int]],

src/array_api_extra/_lib/_funcs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,36 @@ def expand_dims(
304304
return a
305305

306306

307+
def isclose(
308+
a: Array,
309+
b: Array,
310+
*,
311+
rtol: float = 1e-05,
312+
atol: float = 1e-08,
313+
equal_nan: bool = False,
314+
xp: ModuleType | None = None,
315+
) -> Array: # numpydoc ignore=PR01,RT01
316+
"""See docstring in array_api_extra._delegation."""
317+
xp = array_namespace(a, b) if xp is None else xp
318+
319+
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
320+
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
321+
if a_inexact or b_inexact:
322+
# FIXME: use scipy's lazywhere to suppress warnings on inf
323+
out = xp.abs(a - b) <= (atol + rtol * xp.abs(b))
324+
out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out)
325+
if equal_nan:
326+
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
327+
return out
328+
329+
# integer types
330+
atol = int(atol)
331+
if rtol == 0:
332+
return xp.abs(a - b) <= atol
333+
nrtol = int(1.0 / rtol)
334+
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)
335+
336+
307337
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
308338
"""
309339
Kronecker product of two arrays.

tests/test_funcs.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
import pytest
77

88
from array_api_extra import (
9+
allclose,
910
at,
1011
atleast_nd,
1112
cov,
1213
create_diagonal,
1314
expand_dims,
15+
isclose,
1416
kron,
1517
nunique,
1618
pad,
@@ -23,7 +25,7 @@
2325
from array_api_extra._lib._utils._typing import Array, Device
2426

2527
# some xp backends are untyped
26-
# mypy: disable-error-code=no-untyped-usage
28+
# mypy: disable-error-code=no-untyped-def
2729

2830

2931
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
@@ -252,6 +254,116 @@ def test_xp(self, xp: ModuleType):
252254
assert y.shape == (1, 1, 1, 3)
253255

254256

257+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
258+
class TestIsClose:
259+
# FIXME use lazywhere to avoid warnings on inf
260+
@pytest.mark.filterwarnings("ignore:invalid value encountered")
261+
@pytest.mark.parametrize(
262+
("a", "b"),
263+
[
264+
(0.0, 0.0),
265+
(1.0, 1.0),
266+
(1.0, 2.0),
267+
(1.0, -1.0),
268+
(100.0, 101.0),
269+
(0, 0),
270+
(1, 1),
271+
(1, 2),
272+
(1, -1),
273+
(1.0 + 1j, 1.0 + 1j),
274+
(1.0 + 1j, 1.0 - 1j),
275+
(float("inf"), float("inf")),
276+
(float("inf"), 100.0),
277+
(float("inf"), float("-inf")),
278+
(float("nan"), float("nan")),
279+
(float("nan"), 0.0),
280+
(0.0, float("nan")),
281+
(1e6, 1e6 + 1), # True - within rtol
282+
(1e6, 1e6 + 100), # False - outside rtol
283+
(1e-6, 1.1e-6), # False - outside atol
284+
(1e-7, 1.1e-7), # True - outside atol
285+
(1e6 + 0j, 1e6 + 1j), # True - within rtol
286+
(1e6 + 0j, 1e6 + 100j), # False - outside rtol
287+
],
288+
)
289+
def test_basic(self, a: float, b: float, xp: ModuleType):
290+
a_xp = xp.asarray(a)
291+
b_xp = xp.asarray(b)
292+
293+
xp_assert_equal(isclose(a_xp, b_xp), xp.asarray(np.isclose(a, b)))
294+
xp_assert_equal(allclose(a_xp, b_xp), xp.asarray(np.allclose(a, b)))
295+
296+
with warnings.catch_warnings():
297+
warnings.simplefilter("ignore")
298+
r_xp = xp.asarray(np.arange(10), dtype=a_xp.dtype)
299+
ar_xp = a_xp * r_xp
300+
br_xp = b_xp * r_xp
301+
ar_np = a * np.arange(10)
302+
br_np = b * np.arange(10)
303+
304+
xp_assert_equal(isclose(ar_xp, br_xp), xp.asarray(np.isclose(ar_np, br_np)))
305+
306+
@pytest.mark.parametrize("dtype", ["float32", "int32"])
307+
def test_broadcast(self, dtype: str, xp: ModuleType):
308+
dtype = getattr(xp, dtype)
309+
a = xp.asarray([1, 2, 3], dtype=dtype)
310+
b = xp.asarray([[1], [5]], dtype=dtype)
311+
actual = isclose(a, b)
312+
expect = xp.asarray(
313+
[[True, False, False], [False, False, False]], dtype=xp.bool
314+
)
315+
316+
xp_assert_equal(actual, expect)
317+
318+
# FIXME use lazywhere to avoid warnings on inf
319+
@pytest.mark.filterwarnings("ignore:invalid value encountered")
320+
def test_some_inf(self, xp: ModuleType):
321+
a = xp.asarray([0.0, 1.0, float("inf"), float("inf"), float("inf")])
322+
b = xp.asarray([1e-9, 1.0, float("inf"), float("-inf"), 2.0])
323+
actual = isclose(a, b)
324+
xp_assert_equal(actual, xp.asarray([True, True, True, False, False]))
325+
326+
def test_equal_nan(self, xp: ModuleType):
327+
a = xp.asarray([float("nan"), float("nan"), 1.0])
328+
b = xp.asarray([float("nan"), 1.0, float("nan")])
329+
xp_assert_equal(isclose(a, b), xp.asarray([False, False, False]))
330+
xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))
331+
xp_assert_equal(allclose(a[:1], b[:1]), xp.asarray(False))
332+
xp_assert_equal(allclose(a[:1], b[:1], equal_nan=True), xp.asarray(True))
333+
334+
@pytest.mark.parametrize("dtype", ["float32", "complex64", "int32"])
335+
def test_tolerance(self, dtype: str, xp: ModuleType):
336+
dtype = getattr(xp, dtype)
337+
a = xp.asarray([100, 100], dtype=dtype)
338+
b = xp.asarray([101, 102], dtype=dtype)
339+
xp_assert_equal(isclose(a, b), xp.asarray([False, False]))
340+
xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, False]))
341+
xp_assert_equal(isclose(a, b, rtol=0.01), xp.asarray([True, False]))
342+
xp_assert_equal(allclose(a[:1], b[:1]), xp.asarray(False))
343+
xp_assert_equal(allclose(a[:1], b[:1], atol=1), xp.asarray(True))
344+
xp_assert_equal(allclose(a[:1], b[:1], rtol=0.01), xp.asarray(True))
345+
346+
# Attempt to trigger division by 0 in rtol on int dtype
347+
xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False]))
348+
xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False]))
349+
xp_assert_equal(allclose(a[:1], b[:1], rtol=0), xp.asarray(False))
350+
xp_assert_equal(allclose(a[:1], b[:1], atol=1, rtol=0), xp.asarray(True))
351+
352+
def test_very_small_numbers(self, xp: ModuleType):
353+
a = xp.asarray([1e-9, 1e-9])
354+
b = xp.asarray([1.0001e-9, 1.00001e-9])
355+
# Difference is below default atol
356+
xp_assert_equal(isclose(a, b), xp.asarray([True, True]))
357+
# Use only rtol
358+
xp_assert_equal(isclose(a, b, atol=0), xp.asarray([False, True]))
359+
xp_assert_equal(isclose(a, b, atol=0, rtol=0), xp.asarray([False, False]))
360+
361+
def test_xp(self, xp: ModuleType):
362+
a = xp.asarray([0.0, 0.0])
363+
b = xp.asarray([1e-9, 1e-4])
364+
xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))
365+
366+
255367
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
256368
class TestKron:
257369
def test_basic(self, xp: ModuleType):

0 commit comments

Comments
 (0)