Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
cov
create_diagonal
expand_dims
isclose
kron
nunique
pad
Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ checks = [
"all", # report on all checks, except the below
"EX01", # most docstrings do not need an example
"SA01", # data-apis/array-api-extra#87
"SA04", # Missing description for See Also cross-reference
"ES01", # most docstrings do not need an extended summary
]
exclude = [ # don't report on objects that match any of these regex
Expand Down
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import pad
from ._delegation import isclose, pad
from ._lib._at import at
from ._lib._funcs import (
atleast_nd,
Expand All @@ -23,6 +23,7 @@
"cov",
"create_diagonal",
"expand_dims",
"isclose",
"kron",
"nunique",
"pad",
Expand Down
91 changes: 90 additions & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._lib._utils._compat import array_namespace
from ._lib._utils._typing import Array

__all__ = ["pad"]
__all__ = ["isclose", "pad"]


def _delegate(xp: ModuleType, *backends: Backend) -> bool:
Expand All @@ -30,6 +30,95 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:
return any(backend.is_namespace(xp) for backend in backends)


def isclose(
a: Array,
b: Array,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
xp: ModuleType | None = None,
) -> Array:
"""
Return a boolean array where two arrays are element-wise equal within a tolerance.

The tolerance values are positive, typically very small numbers. The relative
difference `(rtol * abs(b))` and the absolute difference atol are added together to
compare against the absolute difference between a and b.

NaNs are treated as equal if they are in the same place and if equal_nan=True. Infs
are treated as equal if they are in the same place and of the same sign in both
arrays.

Parameters
----------
a, b : Array
Input arrays to compare.
rtol : array_like, optional
The relative tolerance parameter (see Notes).
atol : array_like, optional
The absolute tolerance parameter (see Notes).
equal_nan : bool, optional
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
equal to NaN's in `b` in the output array.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Returns
-------
Array
A boolean array of shape broadcasted from `a` and `b`, containing `True` where
``a`` is close to ``b``, and `False` otherwise.

Warnings
--------
The default atol is not appropriate for comparing numbers with magnitudes much
smaller than one ) (see notes).

See Also
--------
math.isclose

Notes
-----
For finite values, `isclose` uses the following equation to test whether two
floating point values are equivalent::

absolute(a - b) <= (atol + rtol * absolute(b))

Unlike the built-in `math.isclose`, the above equation is not symmetric in a and b,
so that `isclose(a, b)` might be different from `isclose(b, a)` in some rare
cases.

The default value of `atol` is not appropriate when the reference value `b` has
magnitude smaller than one. For example, it is unlikely that ``a = 1e-9`` and
``b = 2e-9`` should be considered "close", yet ``isclose(1e-9, 2e-9)`` is `True`
with default settings. Be sure to select atol for the use case at hand, especially
for defining the threshold below which a non-zero value in `a` will be considered
"close" to a very small or zero value in `b`.

The comparison of `a` and `b` uses standard broadcasting, which means that `a` and
`b` need not have the same shape in order for `isclose(a, b)` to evaluate to
`True`.

`isclose` is not defined for non-numeric data types. `bool` is considered a numeric
data-type for this purpose.
"""
xp = array_namespace(a, b) if xp is None else xp

if _delegate(
xp,
Backend.NUMPY,
Backend.CUPY,
Backend.DASK,
Backend.JAX,
Backend.TORCH,
):
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)


def pad(
x: Array,
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
Expand Down
35 changes: 35 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,41 @@ def expand_dims(
return a


def isclose(
a: Array,
b: Array,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
xp = array_namespace(a, b) if xp is None else xp

a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
if a_inexact or b_inexact:
# FIXME: use scipy's lazywhere to suppress warnings on inf
out = xp.abs(a - b) <= (atol + rtol * xp.abs(b))
out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out)
if equal_nan:
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
return out

if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"):
if atol >= 1 or rtol >= 1:
return xp.ones_like(a == b)
Copy link
Contributor Author

@crusaderky crusaderky Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On eager backends, this is less performant than

return xp.ones(xp.broadcast_arrays(a, b)[0], dtype=bool, device=a.device)

but it supports backends with NaN shapes like Dask.
Both jax.jit and dask with non-NaN shape should elide the comparison away.

return a == b

# integer types
atol = int(atol)
if rtol == 0:
return xp.abs(a - b) <= atol
nrtol = int(1.0 / rtol)
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)


def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Kronecker product of two arrays.
Expand Down
15 changes: 15 additions & 0 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
The expected array (typically hardcoded).
err_msg : str, optional
Error message to display on failure.

See Also
--------
xp_assert_close
numpy.testing.assert_array_equal
"""
xp = _check_ns_shape_dtype(actual, desired)

Expand Down Expand Up @@ -112,6 +117,16 @@ def xp_assert_close(
Absolute tolerance. Default: 0.
err_msg : str, optional
Error message to display on failure.

See Also
--------
xp_assert_equal
isclose
numpy.testing.assert_allclose

Notes
-----
The default `atol` and `rtol` differ from `xp.all(xpx.allclose(a, b))`.
"""
xp = _check_ns_shape_dtype(actual, desired)

Expand Down
122 changes: 121 additions & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
cov,
create_diagonal,
expand_dims,
isclose,
kron,
nunique,
pad,
Expand All @@ -23,7 +24,7 @@
from array_api_extra._lib._utils._typing import Array, Device

# some xp backends are untyped
# mypy: disable-error-code=no-untyped-usage
# mypy: disable-error-code=no-untyped-def


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
Expand Down Expand Up @@ -252,6 +253,125 @@ def test_xp(self, xp: ModuleType):
assert y.shape == (1, 1, 1, 3)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestIsClose:
# FIXME use lazywhere to avoid warnings on inf
@pytest.mark.filterwarnings("ignore:invalid value encountered")
@pytest.mark.parametrize(
("a", "b"),
[
(0.0, 0.0),
(1.0, 1.0),
(1.0, 2.0),
(1.0, -1.0),
(100.0, 101.0),
(0, 0),
(1, 1),
(1, 2),
(1, -1),
(1.0 + 1j, 1.0 + 1j),
(1.0 + 1j, 1.0 - 1j),
(float("inf"), float("inf")),
(float("inf"), 100.0),
(float("inf"), float("-inf")),
(float("nan"), float("nan")),
(float("nan"), 0.0),
(0.0, float("nan")),
(1e6, 1e6 + 1), # True - within rtol
(1e6, 1e6 + 100), # False - outside rtol
(1e-6, 1.1e-6), # False - outside atol
(1e-7, 1.1e-7), # True - outside atol
(1e6 + 0j, 1e6 + 1j), # True - within rtol
(1e6 + 0j, 1e6 + 100j), # False - outside rtol
],
)
def test_basic(self, a: float, b: float, xp: ModuleType):
a_xp = xp.asarray(a)
b_xp = xp.asarray(b)

xp_assert_equal(isclose(a_xp, b_xp), xp.asarray(np.isclose(a, b)))

with warnings.catch_warnings():
warnings.simplefilter("ignore")
r_xp = xp.asarray(np.arange(10), dtype=a_xp.dtype)
ar_xp = a_xp * r_xp
br_xp = b_xp * r_xp
ar_np = a * np.arange(10)
br_np = b * np.arange(10)

xp_assert_equal(isclose(ar_xp, br_xp), xp.asarray(np.isclose(ar_np, br_np)))

@pytest.mark.parametrize("dtype", ["float32", "int32"])
def test_broadcast(self, dtype: str, xp: ModuleType):
dtype = getattr(xp, dtype)
a = xp.asarray([1, 2, 3], dtype=dtype)
b = xp.asarray([[1], [5]], dtype=dtype)
actual = isclose(a, b)
expect = xp.asarray(
[[True, False, False], [False, False, False]], dtype=xp.bool
)

xp_assert_equal(actual, expect)

# FIXME use lazywhere to avoid warnings on inf
@pytest.mark.filterwarnings("ignore:invalid value encountered")
def test_some_inf(self, xp: ModuleType):
a = xp.asarray([0.0, 1.0, float("inf"), float("inf"), float("inf")])
b = xp.asarray([1e-9, 1.0, float("inf"), float("-inf"), 2.0])
actual = isclose(a, b)
xp_assert_equal(actual, xp.asarray([True, True, True, False, False]))

def test_equal_nan(self, xp: ModuleType):
a = xp.asarray([float("nan"), float("nan"), 1.0])
b = xp.asarray([float("nan"), 1.0, float("nan")])
xp_assert_equal(isclose(a, b), xp.asarray([False, False, False]))
xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))

@pytest.mark.parametrize("dtype", ["float32", "complex64", "int32"])
def test_tolerance(self, dtype: str, xp: ModuleType):
dtype = getattr(xp, dtype)
a = xp.asarray([100, 100], dtype=dtype)
b = xp.asarray([101, 102], dtype=dtype)
xp_assert_equal(isclose(a, b), xp.asarray([False, False]))
xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, False]))
xp_assert_equal(isclose(a, b, rtol=0.01), xp.asarray([True, False]))

# Attempt to trigger division by 0 in rtol on int dtype
xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False]))
xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False]))

def test_very_small_numbers(self, xp: ModuleType):
a = xp.asarray([1e-9, 1e-9])
b = xp.asarray([1.0001e-9, 1.00001e-9])
# Difference is below default atol
xp_assert_equal(isclose(a, b), xp.asarray([True, True]))
# Use only rtol
xp_assert_equal(isclose(a, b, atol=0), xp.asarray([False, True]))
xp_assert_equal(isclose(a, b, atol=0, rtol=0), xp.asarray([False, False]))

def test_bool_dtype(self, xp: ModuleType):
a = xp.asarray([False, True, False])
b = xp.asarray([True, True, False])
xp_assert_equal(isclose(a, b), xp.asarray([False, True, True]))
xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, True, True]))
xp_assert_equal(isclose(a, b, atol=2), xp.asarray([True, True, True]))
xp_assert_equal(isclose(a, b, rtol=1), xp.asarray([True, True, True]))
xp_assert_equal(isclose(a, b, rtol=2), xp.asarray([True, True, True]))

# Test broadcasting
xp_assert_equal(
isclose(a, xp.asarray(True), atol=1), xp.asarray([True, True, True])
)
xp_assert_equal(
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
)

def test_xp(self, xp: ModuleType):
a = xp.asarray([0.0, 0.0])
b = xp.asarray([1e-9, 1e-4])
xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
class TestKron:
def test_basic(self, xp: ModuleType):
Expand Down
Loading