@@ -10,7 +10,7 @@ Support for the array API standard
This guide describes how to **use** and **add support for** the
`Python array API standard <https://data-apis.org/array-api/latest/index.html>`_.
This standard allows users to use any array API compatible array library
- with SciPy out of the box.
+ with parts of SciPy out of the box.

The `RFC`_ defines how SciPy implements support for the standard, with the main
principle being *"array type in equals array type out"*. In addition, the
@@ -57,7 +57,7 @@ values:

Note that the above example works for PyTorch CPU tensors. For GPU tensors or
CuPy arrays, the expected result for ``vq`` is a ``TypeError``, because ``vq``
- is not a pure Python function and hence won't work on GPU.
+ uses compiled code in its implementation, which won't work on GPU.

More strict array input validation will reject ``np.matrix`` and
``np.ma.MaskedArray`` instances, as well as arrays with ``object`` dtype:
@@ -95,10 +95,12 @@ Currently supported functionality
The following modules provide array API standard support when the environment
variable is set:

- - `scipy.cluster.hierarchy`
- - `scipy.cluster.vq`
+ - `scipy.cluster`
- `scipy.constants`
+ - `scipy.datasets`
- `scipy.fft`
+ - `scipy.io`
+ - `scipy.ndimage`

Support is provided in `scipy.special` for the following functions:
`scipy.special.log_ndtr`, `scipy.special.ndtr`, `scipy.special.ndtri`,
@@ -119,29 +121,31 @@ Support is provided in `scipy.stats` for the following functions:
`scipy.stats.jarque_bera`, `scipy.stats.bartlett`, `scipy.stats.power_divergence`,
and `scipy.stats.monte_carlo_test`.

+ Please see `the tracker issue`_ for updates.
+

Implementation notes
--------------------

A key part of the support for the array API standard and specific compatibility
functions for NumPy, CuPy and PyTorch is provided through
`array-api-compat <https://github.com/data-apis/array-api-compat>`_.
- This package is included in the SciPy code base via a git submodule (under
+ This package is included in the SciPy codebase via a git submodule (under
``scipy/_lib``), so no new dependencies are introduced.

``array-api-compat`` provides generic utility functions and adds aliases such
- as ``xp.concat`` (which, for numpy, maps to ``np.concatenate``). This allows
- using a uniform API across NumPy, PyTorch, CuPy and JAX (with other libraries,
- such as Dask, coming in the future).
+ as ``xp.concat`` (which, for NumPy, mapped to ``np.concatenate`` before NumPy added
+ ``np.concat`` in NumPy 2.0). This allows using a uniform API across NumPy, PyTorch,
+ CuPy and JAX (with other libraries, such as Dask, being worked on).

When the environment variable isn't set and hence array API standard support in
- SciPy is disabled, we still use the "augmented" version of the NumPy namespace,
+ SciPy is disabled, we still use the wrapped version of the NumPy namespace,
which is ``array_api_compat.numpy``. That should not change behavior of SciPy
- functions, it's effectively the existing ``numpy`` namespace with a number of
+ functions, as it's effectively the existing ``numpy`` namespace with a number of
aliases added and a handful of functions amended/added for array API standard
- support. When support is enabled, depending on the type of arrays, ``xp`` will
- return the standard-compatible namespace matching the input array type to a
- function (e.g., if the input to `cluster.vq.kmeans` is a PyTorch array, then
+ support. When support is enabled, ``xp = array_namespace(input)`` will
+ be the standard-compatible namespace matching the input array type to a
+ function (e.g., if the input to `cluster.vq.kmeans` is a PyTorch tensor, then
``xp`` is ``array_api_compat.torch``).

@@ -154,20 +158,9 @@ idioms for NumPy usage as well). By following the standard, effectively adding
support for the array API standard is typically straightforward, and we ideally
don't need to maintain any customization.

- Three helper functions are available:
-
- * ``array_namespace``: return the namespace based on input arrays and do some
-   input validation (like refusing to work with masked arrays, please see the
-   `RFC`_.)
- * ``_asarray``: a drop-in replacement for ``asarray`` with the additional
-   parameters ``check_finite`` and ``order``. As stated above, try to limit
-   the use of non-standard features. In the end we would want to upstream our
-   needs to the compatibility library. Passing ``xp=xp`` avoids duplicate calls
-   of ``array_namespace`` internally.
- * ``copy``: an alias for ``_asarray(x, copy=True)``.
-   The ``copy`` parameter was only introduced to ``np.asarray`` in NumPy 2.0,
-   so use of the helper is needed to support ``<2.0``. Passing ``xp=xp`` avoids
-   duplicate calls of ``array_namespace`` internally.
+ Various helper functions are available in ``scipy._lib._array_api``; please see
+ the ``__all__`` in that module for a list of current helpers, and their docstrings
+ for more information.

To add support to a SciPy function which is defined in a ``.py`` file, what you
have to change is:
@@ -183,11 +176,13 @@ Input array validation uses the following pattern::
    # alternatively, if there are multiple array inputs, include them all:
    xp = array_namespace(arr1, arr2)

+     # replace np.asarray with xp.asarray
+     arr = xp.asarray(arr)
    # uses of non-standard parameters of np.asarray can be replaced with _asarray
    arr = _asarray(arr, order='C', dtype=xp.float64, xp=xp)

- Note that if one input is a non-numpy array type, all array-like inputs have to
- be of that type; trying to mix non-numpy arrays with lists, Python scalars or
+ Note that if one input is a non-NumPy array type, all array-like inputs have to
+ be of that type; trying to mix non-NumPy arrays with lists, Python scalars or
other arbitrary Python objects will raise an exception. For NumPy arrays, those
types will continue to be accepted for backwards compatibility reasons.
@@ -218,7 +213,7 @@ You would convert this like so::
    def toto(a, b):
        xp = array_namespace(a, b)
        a = xp.asarray(a)
-         b = copy(b, xp=xp)  # our custom helper is needed for copy
+         b = xp_copy(b, xp=xp)  # our custom helper is needed for copy

        c = xp.sum(a) - xp.prod(b)
@@ -231,7 +226,7 @@ You would convert this like so::

Going through compiled code requires going back to a NumPy array, because
SciPy's extension modules only work with NumPy arrays (or memoryviews in the
- case of Cython), but not with other array types. For arrays on CPU, the
+ case of Cython). For arrays on CPU, the
conversions should be zero-copy, while on GPU and other devices the attempt at
conversion will raise an exception. The reason for that is that silent data
transfer between devices is considered bad practice, as it is likely to be a
@@ -245,13 +240,13 @@ The following pytest markers are available:

* ``array_api_compatible -> xp``: use a parametrisation to run a test on
  multiple array backends.
- * ``skip_xp_backends(*backends, reasons=None, np_only=False, cpu_only=False)``:
-   skip certain backends and/or devices. ``np_only`` skips tests for all backends
-   other than the default NumPy backend.
+ * ``skip_xp_backends(*backends, reasons=None, np_only=False, cpu_only=False, exceptions=None)``:
+   skip certain backends and/or devices.
  ``@pytest.mark.usefixtures("skip_xp_backends")`` must be used alongside this
-   marker for the skipping to apply.
+   marker for the skipping to apply. See the fixture's docstring in ``scipy.conftest``
+   for information on how to use this marker to skip tests.
* ``skip_xp_invalid_arg`` is used to skip tests that use arguments which
-   are invalid when ``SCIPY_ARRAY_API`` is used. For instance, some tests of
+   are invalid when ``SCIPY_ARRAY_API`` is enabled. For instance, some tests of
  `scipy.stats` functions pass masked arrays to the function being tested, but
  masked arrays are incompatible with the array API. Use of the
  ``skip_xp_invalid_arg`` decorator allows these tests to protect against
@@ -263,41 +258,58 @@ The following pytest markers are available:
default and only behavior, these tests (and the decorator itself) will be
removed.

- The following is an example using the markers::
+ ``scipy._lib._array_api`` contains array-agnostic assertions such as ``xp_assert_close``
+ which can be used to replace assertions from `numpy.testing`.
+
+ The following examples demonstrate how to use the markers::

    from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
+     from scipy._lib._array_api import xp_assert_close
    ...
-     @pytest.mark.skip_xp_backends(np_only=True,
-                                   reasons=['skip reason'])
+     @pytest.mark.skip_xp_backends(np_only=True, reasons=['skip reason'])
    @pytest.mark.usefixtures("skip_xp_backends")
    @array_api_compatible
    def test_toto1(self, xp):
        a = xp.asarray([1, 2, 3])
        b = xp.asarray([0, 2, 5])
-         toto(a, b)
+         xp_assert_close(toto(a, b), a)
    ...
    @pytest.mark.skip_xp_backends('array_api_strict', 'cupy',
-                                     reasons=['skip reason 1',
-                                              'skip reason 2',])
+                                  reasons=['skip reason 1',
+                                           'skip reason 2',])
    @pytest.mark.usefixtures("skip_xp_backends")
    @array_api_compatible
    def test_toto2(self, xp):
-         a = xp.asarray([1, 2, 3])
-         b = xp.asarray([0, 2, 5])
-         toto(a, b)
+         ...
    ...
    # Do not run when SCIPY_ARRAY_API is used
    @skip_xp_invalid_arg
    def test_toto_masked_array(self):
-         a = np.ma.asarray([1, 2, 3])
-         b = np.ma.asarray([0, 2, 5])
-         toto(a, b)
+         ...

Passing a custom reason to ``reasons`` when ``cpu_only=True`` is unsupported
since ``cpu_only=True`` can be used alongside passing ``backends``. Also,
the reason for using ``cpu_only`` is likely just that compiled code is used
in the function(s) being tested.

+ Passing names of backends into ``exceptions`` means that they will not be skipped
+ by ``cpu_only=True``. This is useful when delegation is implemented for some,
+ but not all, non-CPU backends, and the CPU code path requires conversion to NumPy
+ for compiled code::
+
+     # array-api-strict and CuPy will always be skipped, for the given reasons.
+     # All libraries using a non-CPU device will also be skipped, apart from
+     # JAX, for which delegation is implemented (hence non-CPU execution is supported).
+     @pytest.mark.skip_xp_backends('array_api_strict', 'cupy',
+                                   reasons=['skip reason 1',
+                                            'skip reason 2',],
+                                   cpu_only=True,
+                                   exceptions=['jax.numpy'],)
+     @pytest.mark.usefixtures("skip_xp_backends")
+     @array_api_compatible
+     def test_toto(self, xp):
+         ...
+

When every test function in a file has been updated for array API
compatibility, one can reduce verbosity by telling ``pytest`` to apply the
markers to every test function using ``pytestmark``::
@@ -309,9 +321,7 @@ markers to every test function using ``pytestmark``::
    ...
    @skip_xp_backends(np_only=True, reasons=['skip reason'])
    def test_toto1(self, xp):
-         a = xp.asarray([1, 2, 3])
-         b = xp.asarray([0, 2, 5])
-         toto(a, b)
+         ...

After applying these markers, ``dev.py test`` can be used with the new option
``-b`` or ``--array-api-backend``::
@@ -321,12 +331,12 @@ After applying these markers, ``dev.py test`` can be used with the new option
This automatically sets ``SCIPY_ARRAY_API`` appropriately. To test a library
that has multiple devices with a non-default device, a second environment
variable (``SCIPY_DEVICE``, only used in the test suite) can be set. Valid
- values depend on the array library under test, e.g. for PyTorch (currently the
- only library with multi-device support that is known to work) valid values are
- ``"cpu", "cuda", "mps"``. So to run the test suite with the PyTorch MPS
+ values depend on the array library under test, e.g. for PyTorch, valid values are
+ ``"cpu", "cuda", "mps"``. To run the test suite with the PyTorch MPS
backend, use: ``SCIPY_DEVICE=mps python dev.py test -b pytorch``.

- Note that there is a GitHub Actions workflow which runs ``pytorch-cpu``.
+ Note that there is a GitHub Actions workflow which tests with array-api-strict,
+ PyTorch, and JAX on CPU.


Additional information
@@ -346,3 +356,4 @@ helped during the development phase:
`#25956 <https://github.com/scikit-learn/scikit-learn/pull/25956>`__

.. _RFC: https://github.com/scipy/scipy/issues/18286
+ .. _the tracker issue: https://github.com/scipy/scipy/issues/18867