Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def _op(
msg = f"Can't update read-only array {x}"
raise ValueError(msg)

# Backends without boolean indexing (other than JAX) crash here
if in_place_op: # add(), subtract(), ...
x[idx] = in_place_op(x[idx], y)
else: # set()
Expand Down
21 changes: 18 additions & 3 deletions src/array_api_extra/_lib/_backends.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Backends with which array-api-extra interacts in delegation and testing."""

from __future__ import annotations

from collections.abc import Callable
from enum import Enum
from types import ModuleType
from typing import cast

from ._utils import _compat

Expand All @@ -23,9 +24,14 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
corresponding to the backend.
"""

# Use :<tag> to prevent Enum from deduplicating items with the same value
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
ARRAY_API_STRICTEST = (
"array_api_strict:strictest",
_compat.is_array_api_strict_namespace,
)
NUMPY = "numpy", _compat.is_numpy_namespace
NUMPY_READONLY = "numpy_readonly", _compat.is_numpy_namespace
NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace
CUPY = "cupy", _compat.is_cupy_namespace
TORCH = "torch", _compat.is_torch_namespace
DASK = "dask.array", _compat.is_dask_namespace
Expand All @@ -48,4 +54,13 @@ def __init__(

def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
"""Pretty-print parameterized test names."""
return cast(str, self.value)
return self.name.lower()

@property
def modname(self) -> str: # numpydoc ignore=RT01
"""Module name to be imported."""
return self.value.split(":")[0]

def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
"""Check if this backend uses the same module as others."""
return any(self.modname == other.modname for other in others)
51 changes: 36 additions & 15 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import (
array_namespace,
is_dask_namespace,
is_jax_array,
is_jax_namespace,
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
from ._utils._helpers import (
asarrays,
capabilities,
eager_shape,
meta_namespace,
ndindex,
)
from ._utils._helpers import asarrays, eager_shape, meta_namespace, ndindex
from ._utils._typing import Array

__all__ = [
Expand Down Expand Up @@ -152,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
) -> Array:
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""

if is_jax_namespace(xp):
if not capabilities(xp)["boolean indexing"]:
# jax.jit does not support assignment by boolean mask
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)

Expand Down Expand Up @@ -708,14 +709,34 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
# size= is JAX-specific
# https://github.com/data-apis/array-api/issues/883
_, counts = xp.unique_counts(x, size=_compat.size(x))
return xp.astype(counts, xp.bool).sum()

_, counts = xp.unique_counts(x)
n = _compat.size(counts)
# FIXME https://github.com/data-apis/array-api-compat/pull/231
if n is None: # e.g. Dask, ndonnx
return xp.astype(counts, xp.bool).sum()
return xp.asarray(n, device=_compat.device(x))
return (counts > 0).sum()

# There are 3 general use cases:
# 1. backend has unique_counts and it returns an array with known shape
# 2. backend has unique_counts and it returns a None-sized array;
# e.g. Dask, ndonnx
# 3. backend does not have unique_counts; e.g. wrapped JAX
if capabilities(xp)["data-dependent shapes"]:
# xp has unique_counts; O(n) complexity
_, counts = xp.unique_counts(x)
n = _compat.size(counts)
if n is None:
return xp.sum(xp.ones_like(counts))
return xp.asarray(n, device=_compat.device(x))

# xp does not have unique_counts; O(n*logn) complexity
x = xp.sort(xp.reshape(x, -1))
mask = x != xp.roll(x, -1)
default_int = xp.__array_namespace_info__().default_dtypes(
device=_compat.device(x)
)["integral"]
return xp.maximum(
# Special cases:
# - array is size 0
# - array has all elements equal to each other
xp.astype(xp.any(~mask), default_int),
xp.sum(xp.astype(mask, default_int)),
)


def pad(
Expand Down
33 changes: 33 additions & 0 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
array_namespace,
is_array_api_obj,
is_dask_namespace,
is_jax_namespace,
is_numpy_array,
is_pydata_sparse_namespace,
)
from ._typing import Array

Expand All @@ -23,6 +25,7 @@

__all__ = [
"asarrays",
"capabilities",
"eager_shape",
"in1d",
"is_python_scalar",
Expand Down Expand Up @@ -270,3 +273,33 @@ def meta_namespace(
# Quietly skip scalars and None's
metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays]
return array_namespace(*metas)


def capabilities(xp: ModuleType) -> dict[str, int]:
"""
Return patched ``xp.__array_namespace_info__().capabilities()``.
Parameters
----------
xp : array_namespace
The standard-compatible namespace.
Returns
-------
dict
Capabilities of the namespace.
"""
if is_pydata_sparse_namespace(xp):
# No __array_namespace_info__(); no indexing by sparse arrays
return {"boolean indexing": False, "data-dependent shapes": True}
out = xp.__array_namespace_info__().capabilities()
if is_jax_namespace(xp):
# FIXME https://github.com/jax-ml/jax/issues/27418
out = out.copy()
out["boolean indexing"] = False
if is_dask_namespace(xp):
# FIXME https://github.com/data-apis/array-api-compat/pull/290
out = out.copy()
out["boolean indexing"] = True
out["data-dependent shapes"] = True
return out
32 changes: 23 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Pytest fixtures."""

from collections.abc import Callable
from collections.abc import Callable, Generator
from contextlib import suppress
from functools import partial, wraps
from types import ModuleType
Expand All @@ -19,6 +19,7 @@
T = TypeVar("T")
P = ParamSpec("P")

NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2])
np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


Expand All @@ -43,7 +44,7 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
msg = f"argument of {marker_name} must be a Backend enum"
raise TypeError(msg)
if library == elem:
reason = library.value
reason = str(library)
with suppress(KeyError):
reason += ":" + cast(str, marker.kwargs["reason"])
skip_or_xfail(reason=reason)
Expand Down Expand Up @@ -104,7 +105,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
@pytest.fixture
def xp(
library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
) -> ModuleType: # numpydoc ignore=PR01,RT03
) -> Generator[ModuleType]: # numpydoc ignore=PR01,RT03
"""
Parameterized fixture that iterates on all libraries.

Expand All @@ -113,25 +114,38 @@ def xp(
The current array namespace.
"""
if library == Backend.NUMPY_READONLY:
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
xp = pytest.importorskip(library.value)
yield NumPyReadOnly() # type: ignore[misc] # pyright: ignore[reportReturnType]
return

if library.like(Backend.ARRAY_API_STRICT) and NUMPY_VERSION < (1, 26):
pytest.skip("array_api_strict is untested on NumPy <1.26")

xp = pytest.importorskip(library.modname)
# Possibly wrap module with array_api_compat
xp = array_namespace(xp.empty(0))

if library == Backend.ARRAY_API_STRICTEST:
with xp.ArrayAPIStrictFlags(
boolean_indexing=False,
data_dependent_shapes=False,
# writeable=False, # TODO implement in array-api-strict
# lazy=True, # TODO implement in array-api-strict
enabled_extensions=(),
):
yield xp
return

# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
# in the global scope of the module containing the test function.
patch_lazy_xp_functions(request, monkeypatch, xp=xp)

if library == Backend.ARRAY_API_STRICT and np.__version__ < "1.26":
pytest.skip("array_api_strict is untested on NumPy <1.26")

if library == Backend.JAX:
import jax

# suppress unused-ignore to run mypy in -e lint as well as -e dev
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]

return xp
yield xp


@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
Expand Down
5 changes: 3 additions & 2 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
pytestmark = [
pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
),
pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing"),
]


Expand Down Expand Up @@ -256,7 +257,7 @@ def test_incompatible_dtype(
elif library is Backend.DASK:
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
elif library.like(Backend.ARRAY_API_STRICT) and op is not _AtOp.SET:
with pytest.raises(Exception, match=r"cast|promote|dtype"):
_ = at_op(x, idx, op, 1.1, copy=copy)

Expand Down
52 changes: 42 additions & 10 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function

from .conftest import NUMPY_VERSION

# some xp backends are untyped
# mypy: disable-error-code=no-untyped-def

Expand All @@ -48,12 +50,6 @@
lazy_xp_function(sinc, static_argnames="xp")


NUMPY_GE2 = int(np.__version__.split(".")[0]) >= 2


@pytest.mark.skip_xp_backend(
Backend.SPARSE, reason="read-only backend without .at support"
)
class TestApplyWhere:
@staticmethod
def f1(x: Array, y: Array | int = 10) -> Array:
Expand Down Expand Up @@ -153,6 +149,14 @@ def test_dont_overwrite_fill_value(self, xp: ModuleType):
xp_assert_equal(actual, xp.asarray([100, 12]))
xp_assert_equal(fill_value, xp.asarray([100, 200]))

@pytest.mark.skip_xp_backend(
Backend.ARRAY_API_STRICTEST,
reason="no boolean indexing -> run everywhere",
)
@pytest.mark.skip_xp_backend(
Backend.SPARSE,
reason="no indexing by sparse array -> run everywhere",
)
def test_dont_run_on_false(self, xp: ModuleType):
x = xp.asarray([1.0, 2.0, 0.0])
y = xp.asarray([0.0, 3.0, 4.0])
Expand Down Expand Up @@ -192,6 +196,7 @@ def test_device(self, xp: ModuleType, device: Device):
y = apply_where(x % 2 == 0, x, self.f1, fill_value=x)
assert get_device(y) == device

@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc.
@hypothesis.settings(
# The xp and library fixtures are not regenerated between hypothesis iterations
Expand All @@ -217,8 +222,8 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
library: Backend,
):
if (
library in (Backend.NUMPY, Backend.NUMPY_READONLY)
and not NUMPY_GE2
library.like(Backend.NUMPY)
and NUMPY_VERSION < (2, 0)
and dtype is np.float32
):
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")
Expand Down Expand Up @@ -562,6 +567,9 @@ def test_xp(self, xp: ModuleType):
assert y.shape == (1, 1, 1, 3)


@pytest.mark.filterwarnings( # array_api_strictest
"ignore:invalid value encountered:RuntimeWarning:array_api_strict"
)
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestIsClose:
@pytest.mark.parametrize("swap", [False, True])
Expand Down Expand Up @@ -680,13 +688,15 @@ def test_bool_dtype(self, xp: ModuleType):
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
)

@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
def test_none_shape(self, xp: ModuleType):
a = xp.asarray([1, 5, 0])
b = xp.asarray([1, 4, 2])
b = b[a < 5]
a = a[a < 5]
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
def test_none_shape_bool(self, xp: ModuleType):
a = xp.asarray([True, True, False])
b = xp.asarray([True, False, True])
Expand Down Expand Up @@ -819,8 +829,29 @@ def test_empty(self, xp: ModuleType):
a = xp.asarray([])
xp_assert_equal(nunique(a), xp.asarray(0))

def test_device(self, xp: ModuleType, device: Device):
a = xp.asarray(0.0, device=device)
def test_size1(self, xp: ModuleType):
a = xp.asarray([123])
xp_assert_equal(nunique(a), xp.asarray(1))

def test_all_equal(self, xp: ModuleType):
a = xp.asarray([123, 123, 123])
xp_assert_equal(nunique(a), xp.asarray(1))

@pytest.mark.xfail_xp_backend(Backend.DASK, reason="No equal_nan kwarg in unique")
@pytest.mark.xfail_xp_backend(
Backend.SPARSE, reason="Non-compliant equal_nan=True behaviour"
)
def test_nan(self, xp: ModuleType, library: Backend):
if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24):
pytest.xfail("NumPy <1.24 has no equal_nan kwarg in unique")

# Each NaN is counted separately
a = xp.asarray([xp.nan, 123.0, xp.nan])
xp_assert_equal(nunique(a), xp.asarray(3))

@pytest.mark.parametrize("size", [0, 1, 2])
def test_device(self, xp: ModuleType, device: Device, size: int):
a = xp.asarray([0.0] * size, device=device)
assert get_device(nunique(a)) == device

def test_xp(self, xp: ModuleType):
Expand Down Expand Up @@ -895,6 +926,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_values")
class TestSetDiff1D:
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="NaN-shaped arrays")
@pytest.mark.xfail_xp_backend(
Expand Down
Loading