1
1
"""Pytest fixtures."""
2
2
3
+ from collections .abc import Callable
4
+ from functools import wraps
3
5
from types import ModuleType
4
- from typing import cast
6
+ from typing import ParamSpec , TypeVar , cast
5
7
8
+ import numpy as np
6
9
import pytest
7
10
8
11
from array_api_extra ._lib import Backend
9
12
from array_api_extra ._lib ._utils ._compat import array_namespace
10
13
from array_api_extra ._lib ._utils ._compat import device as get_device
11
14
from array_api_extra ._lib ._utils ._typing import Device
12
15
16
+ T = TypeVar ("T" )
17
+ P = ParamSpec ("P" )
18
+
19
+ np_compat = array_namespace (np .empty (0 ))
20
+
13
21
14
22
@pytest .fixture (params = tuple (Backend ))
15
23
def library (request : pytest .FixtureRequest ) -> Backend : # numpydoc ignore=PR01,RT03
@@ -34,6 +42,56 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
34
42
return elem
35
43
36
44
45
+ class NumPyReadOnly :
46
+ """
47
+ Variant of array_api_compat.numpy producing read-only arrays.
48
+
49
+ Read-only numpy arrays fail on `__iadd__` etc., whereas read-only libraries such as
50
+ JAX and Sparse simply don't define those methods, which makes calls to `+=` fall
51
+ back to `__add__`.
52
+
53
+ Note that this is not a full read-only Array API library. Notably,
54
+ `array_namespace(x)` returns array_api_compat.numpy. This is actually the desired
55
+ behaviour, so that when a tested function internally calls `xp =
56
+ array_namespace(*args) or xp`, it will internally create writeable arrays.
57
+ For this reason, tests that explicitly pass xp=xp to the tested functions may
58
+ misbehave and should be skipped for NUMPY_READONLY.
59
+ """
60
+
61
+ def __getattr__ (self , name : str ) -> object : # numpydoc ignore=PR01,RT01
62
+ """Wrap all functions that return arrays to make their output read-only."""
63
+ func = getattr (np_compat , name )
64
+ if not callable (func ) or isinstance (func , type ):
65
+ return func
66
+ return self ._wrap (func )
67
+
68
+ @staticmethod
69
+ def _wrap (func : Callable [P , T ]) -> Callable [P , T ]: # numpydoc ignore=PR01,RT01
70
+ """Wrap func to make all np.ndarrays it returns read-only."""
71
+
72
+ def as_readonly (o : T ) -> T : # numpydoc ignore=PR01,RT01
73
+ """Unset the writeable flag in o."""
74
+ try :
75
+ # Don't use is_numpy_array(o), as it includes np.generic
76
+ if isinstance (o , np .ndarray ):
77
+ o .flags .writeable = False
78
+ except TypeError :
79
+ # Cannot interpret as a data type
80
+ return o
81
+
82
+ # This works with namedtuples too
83
+ if isinstance (o , tuple | list ):
84
+ return type (o )(* (as_readonly (i ) for i in o )) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType]
85
+
86
+ return o
87
+
88
+ @wraps (func )
89
+ def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
90
+ return as_readonly (func (* args , ** kwargs ))
91
+
92
+ return wrapper
93
+
94
+
37
95
@pytest .fixture
38
96
def xp (library : Backend ) -> ModuleType : # numpydoc ignore=PR01,RT03
39
97
"""
@@ -43,7 +101,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
43
101
-------
44
102
The current array namespace.
45
103
"""
46
- xp = pytest .importorskip (library .module_name )
104
+ if library == Backend .NUMPY_READONLY :
105
+ return NumPyReadOnly () # type: ignore[return-value] # pyright: ignore[reportReturnType]
106
+ xp = pytest .importorskip (library .value )
47
107
if library == Backend .JAX_NUMPY :
48
108
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
49
109
0 commit comments