11"""Pytest fixtures."""
22
3+ from __future__ import annotations
4+
5+ from collections .abc import Callable
36from enum import Enum
4- from typing import cast
7+ from functools import wraps
8+ from typing import ParamSpec , TypeVar , cast
59
10+ import numpy as np
611import pytest
712
813from array_api_extra ._lib ._compat import array_namespace
914from array_api_extra ._lib ._compat import device as get_device
1015from array_api_extra ._lib ._typing import Device , ModuleType
1116
17+ T = TypeVar ("T" )
18+ P = ParamSpec ("P" )
19+
20+ np_compat = array_namespace (np .empty (0 ))
21+
1222
1323class Library (Enum ):
1424 """All array libraries explicitly tested by array-api-extra."""
@@ -50,6 +60,56 @@ def library(request: pytest.FixtureRequest) -> Library: # numpydoc ignore=PR01,
5060 return elem
5161
5262
63+ class NumPyReadOnly :
64+ """
65+ Variant of array_api_compat.numpy producing read-only arrays.
66+
67+ Note that this is not a full read-only Array API library. Notably,
68+ array_namespace(x) returns array_api_compat.numpy, and as a consequence array
69+ creation functions invoked internally by the tested functions will return
70+ writeable arrays, as long as you don't explicitly pass xp=xp.
71+ For this reason, tests that do pass xp=xp may misbehave and should be skipped
72+ for NUMPY_READONLY.
73+ """
74+
75+ def __getattr__ (self , name : str ) -> object : # numpydoc ignore=PR01,RT01
76+ """Wrap all functions that return arrays to make their output read-only."""
77+ func = getattr (np_compat , name )
78+ if not callable (func ) or isinstance (func , type ):
79+ return func
80+ return self ._wrap (func )
81+
82+ @staticmethod
83+ def _wrap (func : Callable [P , T ]) -> Callable [P , T ]: # numpydoc ignore=PR01,RT01
84+ """Wrap func to make all np.ndarrays it returns read-only."""
85+
86+ def as_readonly (o : T , seen : set [int ]) -> T : # numpydoc ignore=PR01,RT01
87+ """Unset the writeable flag in o."""
88+ if id (o ) in seen :
89+ return o
90+ seen .add (id (o ))
91+
92+ try :
93+ # Don't use is_numpy_array(o), as it includes np.generic
94+ if isinstance (o , np .ndarray ):
95+ o .flags .writeable = False
96+ except TypeError :
97+ # Cannot interpret as a data type
98+ return o
99+
100+ # This works with namedtuples too
101+ if isinstance (o , tuple | list ):
102+ return type (o )(* (as_readonly (i , seen ) for i in o )) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType,reportUnknownArgumentType]
103+
104+ return o
105+
106+ @wraps (func )
107+ def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
108+ return as_readonly (func (* args , ** kwargs ), seen = set ())
109+
110+ return wrapper
111+
112+
53113@pytest .fixture
54114def xp (library : Library ) -> ModuleType : # numpydoc ignore=PR01,RT03
55115 """
@@ -59,8 +119,9 @@ def xp(library: Library) -> ModuleType: # numpydoc ignore=PR01,RT03
59119 -------
60120 The current array namespace.
61121 """
62- name = "numpy" if library == Library .NUMPY_READONLY else library .value
63- xp = pytest .importorskip (name )
122+ if library == Library .NUMPY_READONLY :
123+ return NumPyReadOnly () # type: ignore[return-value] # pyright: ignore[reportReturnType]
124+ xp = pytest .importorskip (library .value )
64125 if library == Library .JAX_NUMPY :
65126 import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
66127
0 commit comments