11"""Pytest fixtures."""
22
3+ from collections .abc import Callable
4+ from functools import wraps
35from types import ModuleType
4- from typing import cast
6+ from typing import ParamSpec , TypeVar , cast
57
8+ import numpy as np
69import pytest
710
811from array_api_extra ._lib import Backend
912from array_api_extra ._lib ._utils ._compat import array_namespace
1013from array_api_extra ._lib ._utils ._compat import device as get_device
1114from array_api_extra ._lib ._utils ._typing import Device
1215
16+ T = TypeVar ("T" )
17+ P = ParamSpec ("P" )
18+
19+ np_compat = array_namespace (np .empty (0 ))
20+
1321
1422@pytest .fixture (params = tuple (Backend ))
1523def library (request : pytest .FixtureRequest ) -> Backend : # numpydoc ignore=PR01,RT03
@@ -34,6 +42,56 @@ def library(request: pytest.FixtureRequest) -> Backend: # numpydoc ignore=PR01,
3442 return elem
3543
3644
45+ class NumPyReadOnly :
46+ """
47+ Variant of array_api_compat.numpy producing read-only arrays.
48+
49+ Read-only numpy arrays fail on `__iadd__` etc., whereas read-only libraries such as
50+ JAX and Sparse simply don't define those methods, which makes calls to `+=` fall
51+ back to `__add__`.
52+
53+ Note that this is not a full read-only Array API library. Notably,
54+ `array_namespace(x)` returns array_api_compat.numpy. This is actually the desired
55+ behaviour, so that when a tested function internally calls `xp =
56+ array_namespace(*args) or xp`, it will internally create writeable arrays.
57+ For this reason, tests that explicitly pass xp=xp to the tested functions may
58+ misbehave and should be skipped for NUMPY_READONLY.
59+ """
60+
61+ def __getattr__ (self , name : str ) -> object : # numpydoc ignore=PR01,RT01
62+ """Wrap all functions that return arrays to make their output read-only."""
63+ func = getattr (np_compat , name )
64+ if not callable (func ) or isinstance (func , type ):
65+ return func
66+ return self ._wrap (func )
67+
68+ @staticmethod
69+ def _wrap (func : Callable [P , T ]) -> Callable [P , T ]: # numpydoc ignore=PR01,RT01
70+ """Wrap func to make all np.ndarrays it returns read-only."""
71+
72+ def as_readonly (o : T ) -> T : # numpydoc ignore=PR01,RT01
73+ """Unset the writeable flag in o."""
74+ try :
75+ # Don't use is_numpy_array(o), as it includes np.generic
76+ if isinstance (o , np .ndarray ):
77+ o .flags .writeable = False
78+ except TypeError :
79+ # Cannot interpret as a data type
80+ return o
81+
82+ # This works with namedtuples too
83+ if isinstance (o , tuple | list ):
84+ return type (o )(* (as_readonly (i ) for i in o )) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType]
85+
86+ return o
87+
88+ @wraps (func )
89+ def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
90+ return as_readonly (func (* args , ** kwargs ))
91+
92+ return wrapper
93+
94+
3795@pytest .fixture
3896def xp (library : Backend ) -> ModuleType : # numpydoc ignore=PR01,RT03
3997 """
@@ -43,7 +101,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03
43101 -------
44102 The current array namespace.
45103 """
46- xp = pytest .importorskip (library .module_name )
104+ if library == Backend .NUMPY_READONLY :
105+ return NumPyReadOnly () # type: ignore[return-value] # pyright: ignore[reportReturnType]
106+ xp = pytest .importorskip (library .value )
47107 if library == Backend .JAX_NUMPY :
48108 import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
49109
0 commit comments