Skip to content

Commit 992412b

Browse files
committed
TST: Run all tests on read-only numpy arrays
1 parent ec890f1 commit 992412b

File tree

4 files changed

+75
-20
lines changed

4 files changed

+75
-20
lines changed

tests/conftest.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
"""Pytest fixtures."""
22

3+
from __future__ import annotations
4+
5+
from collections.abc import Callable
36
from enum import Enum
4-
from typing import cast
7+
from functools import wraps
8+
from typing import ParamSpec, TypeVar, cast
59

10+
import numpy as np
611
import pytest
712

813
from array_api_extra._lib._compat import array_namespace
914
from array_api_extra._lib._compat import device as get_device
1015
from array_api_extra._lib._typing import Device, ModuleType
1116

17+
T = TypeVar("T")
18+
P = ParamSpec("P")
19+
20+
np_compat = array_namespace(np.empty(0))
21+
1222

1323
class Library(Enum):
1424
"""All array libraries explicitly tested by array-api-extra."""
@@ -50,6 +60,56 @@ def library(request: pytest.FixtureRequest) -> Library: # numpydoc ignore=PR01,
5060
return elem
5161

5262

63+
class NumPyReadOnly:
64+
"""
65+
Variant of array_api_compat.numpy producing read-only arrays.
66+
67+
Note that this is not a full read-only Array API library. Notably,
68+
array_namespace(x) returns array_api_compat.numpy, and as a consequence array
69+
creation functions invoked internally by the tested functions will return
70+
writeable arrays, as long as you don't explicitly pass xp=xp.
71+
For this reason, tests that do pass xp=xp may misbehave and should be skipped
72+
for NUMPY_READONLY.
73+
"""
74+
75+
def __getattr__(self, name: str) -> object: # numpydoc ignore=PR01,RT01
76+
"""Wrap all functions that return arrays to make their output read-only."""
77+
func = getattr(np_compat, name)
78+
if not callable(func) or isinstance(func, type):
79+
return func
80+
return self._wrap(func)
81+
82+
@staticmethod
83+
def _wrap(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=PR01,RT01
84+
"""Wrap func to make all np.ndarrays it returns read-only."""
85+
86+
def as_readonly(o: T, seen: set[int]) -> T: # numpydoc ignore=PR01,RT01
87+
"""Unset the writeable flag in o."""
88+
if id(o) in seen:
89+
return o
90+
seen.add(id(o))
91+
92+
try:
93+
# Don't use is_numpy_array(o), as it includes np.generic
94+
if isinstance(o, np.ndarray):
95+
o.flags.writeable = False
96+
except TypeError:
97+
# Cannot interpret as a data type
98+
return o
99+
100+
# This works with namedtuples too
101+
if isinstance(o, tuple | list):
102+
return type(o)(*(as_readonly(i, seen) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType]
103+
104+
return o
105+
106+
@wraps(func)
107+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
108+
return as_readonly(func(*args, **kwargs), seen=set())
109+
110+
return wrapper
111+
112+
53113
@pytest.fixture
54114
def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03
55115
"""
@@ -59,8 +119,9 @@ def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03
59119
-------
60120
The current array namespace.
61121
"""
62-
name = "numpy" if library == Library.NUMPY_READONLY else library.value
63-
xp = pytest.importorskip(name)
122+
if library == Library.NUMPY_READONLY:
123+
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
124+
xp = pytest.importorskip(library.value)
64125
if library == Library.JAX_NUMPY:
65126
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
66127

tests/test_at.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pytest
77
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
88
array_namespace,
9-
is_pydata_sparse_array,
109
is_writeable_array,
1110
)
1211

@@ -18,14 +17,6 @@
1817
from .conftest import Library
1918

2019

21-
@pytest.fixture
22-
def array(library: Library, xp: ModuleType) -> Array:
23-
x = xp.asarray([10.0, 20.0, 30.0])
24-
if library == Library.NUMPY_READONLY:
25-
x.flags.writeable = False
26-
return x
27-
28-
2920
@contextmanager
3021
def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
3122
if copy is False and not is_writeable_array(array):
@@ -42,6 +33,9 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
4233
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))
4334

4435

36+
@pytest.mark.skip_xp_backend(
37+
Library.SPARSE, reason="read-only library without .at support"
38+
)
4539
@pytest.mark.parametrize(
4640
("kwargs", "expect_copy"),
4741
[
@@ -66,15 +60,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
6660
)
6761
def test_update_ops(
6862
xp: ModuleType,
69-
array: Array,
7063
kwargs: dict[str, bool | None],
7164
expect_copy: bool | None,
7265
op: _AtOp,
7366
arg: float,
7467
expect: list[float],
7568
):
76-
if is_pydata_sparse_array(array):
77-
pytest.skip("at() does not support updates on sparse arrays")
69+
array = xp.asarray([10.0, 20.0, 30.0])
7870

7971
with assert_copy(array, expect_copy):
8072
func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit]

tests/test_funcs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def test_device(self, xp: ModuleType, device: Device):
136136
x = xp.asarray([1, 2, 3], device=device)
137137
assert get_device(cov(x)) == device
138138

139+
@pytest.mark.skip_xp_backend(Library.NUMPY_READONLY)
139140
def test_xp(self, xp: ModuleType):
140141
xp_assert_close(
141142
cov(xp.asarray([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]]).T, xp=xp),
@@ -366,6 +367,7 @@ def test_device(self, xp: ModuleType, device: Device):
366367
x2 = xp.asarray([2, 3, 4], device=device)
367368
assert get_device(setdiff1d(x1, x2)) == device
368369

370+
@pytest.mark.skip_xp_backend(Library.NUMPY_READONLY)
369371
def test_xp(self, xp: ModuleType):
370372
x1 = xp.asarray([3, 8, 20])
371373
x2 = xp.asarray([2, 3, 4])

tests/test_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import numpy as np
21
import pytest
32

43
from array_api_extra._lib._compat import device as get_device
54
from array_api_extra._lib._testing import xp_assert_equal
6-
from array_api_extra._lib._typing import Array, Device, ModuleType
5+
from array_api_extra._lib._typing import Device, ModuleType
76
from array_api_extra._lib._utils import in1d
87

98
from .conftest import Library
@@ -15,10 +14,10 @@ class TestIn1D:
1514
@pytest.mark.skip_xp_backend(Library.DASK_ARRAY, reason="no argsort")
1615
@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no unique_inverse, no device")
1716
# cover both code paths
18-
@pytest.mark.parametrize("x2", [np.arange(9), np.arange(15)])
19-
def test_no_invert_assume_unique(self, xp: ModuleType, x2: Array):
17+
@pytest.mark.parametrize("n", [9, 15])
18+
def test_no_invert_assume_unique(self, xp: ModuleType, n: int):
2019
x1 = xp.asarray([3, 8, 20])
21-
x2 = xp.asarray(x2)
20+
x2 = xp.arange(n)
2221
expected = xp.asarray([True, True, False])
2322
actual = in1d(x1, x2)
2423
xp_assert_equal(actual, expected)
@@ -29,6 +28,7 @@ def test_device(self, xp: ModuleType, device: Device):
2928
x2 = xp.asarray([2, 3, 4], device=device)
3029
assert get_device(in1d(x1, x2)) == device
3130

31+
@pytest.mark.skip_xp_backend(Library.NUMPY_READONLY)
3232
@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no arange, no device")
3333
def test_xp(self, xp: ModuleType):
3434
x1 = xp.asarray([1, 6])

0 commit comments

Comments
 (0)