Skip to content

Commit 76bf366

Browse files
authored
Merge pull request scipy#21310 from lucascolley/xp-docs
MAINT/DOC: clean up `_lib._array_api`, update docs
2 parents 94532e7 + 87be62b commit 76bf366

25 files changed

+156
-134
lines changed

doc/source/dev/api-dev/array_api.rst

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Support for the array API standard
1010
This guide describes how to **use** and **add support for** the
1111
`Python array API standard <https://data-apis.org/array-api/latest/index.html>`_.
1212
This standard allows users to use any array API compatible array library
13-
with SciPy out of the box.
13+
with parts of SciPy out of the box.
1414

1515
The `RFC`_ defines how SciPy implements support for the standard, with the main
1616
principle being *"array type in equals array type out"*. In addition, the
@@ -57,7 +57,7 @@ values:
5757
5858
Note that the above example works for PyTorch CPU tensors. For GPU tensors or
5959
CuPy arrays, the expected result for ``vq`` is a ``TypeError``, because ``vq``
60-
is not a pure Python function and hence won't work on GPU.
60+
uses compiled code in its implementation, which won't work on GPU.
6161

6262
More strict array input validation will reject ``np.matrix`` and
6363
``np.ma.MaskedArray`` instances, as well as arrays with ``object`` dtype:
@@ -95,10 +95,12 @@ Currently supported functionality
9595
The following modules provide array API standard support when the environment
9696
variable is set:
9797

98-
- `scipy.cluster.hierarchy`
99-
- `scipy.cluster.vq`
98+
- `scipy.cluster`
10099
- `scipy.constants`
100+
- `scipy.datasets`
101101
- `scipy.fft`
102+
- `scipy.io`
103+
- `scipy.ndimage`
102104

103105
Support is provided in `scipy.special` for the following functions:
104106
`scipy.special.log_ndtr`, `scipy.special.ndtr`, `scipy.special.ndtri`,
@@ -119,29 +121,31 @@ Support is provided in `scipy.stats` for the following functions:
119121
`scipy.stats.jarque_bera`, `scipy.stats.bartlett`, `scipy.stats.power_divergence`,
120122
and `scipy.stats.monte_carlo_test`.
121123

124+
Please see `the tracker issue`_ for updates.
125+
122126

123127
Implementation notes
124128
--------------------
125129

126130
A key part of the support for the array API standard and specific compatibility
127131
functions for Numpy, CuPy and PyTorch is provided through
128132
`array-api-compat <https://github.com/data-apis/array-api-compat>`_.
129-
This package is included in the SciPy code base via a git submodule (under
133+
This package is included in the SciPy codebase via a git submodule (under
130134
``scipy/_lib``), so no new dependencies are introduced.
131135

132136
``array-api-compat`` provides generic utility functions and adds aliases such
133-
as ``xp.concat`` (which, for numpy, maps to ``np.concatenate``). This allows
134-
using a uniform API across NumPy, PyTorch, CuPy and JAX (with other libraries,
135-
such as Dask, coming in the future).
137+
as ``xp.concat`` (which, for numpy, mapped to ``np.concatenate`` before NumPy added
138+
``np.concat`` in NumPy 2.0). This allows using a uniform API across NumPy, PyTorch,
139+
CuPy and JAX (with other libraries, such as Dask, being worked on).
136140

137141
When the environment variable isn't set and hence array API standard support in
138-
SciPy is disabled, we still use the "augmented" version of the NumPy namespace,
142+
SciPy is disabled, we still use the wrapped version of the NumPy namespace,
139143
which is ``array_api_compat.numpy``. That should not change behavior of SciPy
140-
functions, it's effectively the existing ``numpy`` namespace with a number of
144+
functions, as it's effectively the existing ``numpy`` namespace with a number of
141145
aliases added and a handful of functions amended/added for array API standard
142-
support. When support is enabled, depending on the type of arrays, ``xp`` will
143-
return the standard-compatible namespace matching the input array type to a
144-
function (e.g., if the input to `cluster.vq.kmeans` is a PyTorch array, then
146+
support. When support is enabled, ``xp = array_namespace(input)`` will
147+
be the standard-compatible namespace matching the input array type to a
148+
function (e.g., if the input to `cluster.vq.kmeans` is a PyTorch tensor, then
145149
``xp`` is ``array_api_compat.torch``).
146150

147151

@@ -154,20 +158,9 @@ idioms for NumPy usage as well). By following the standard, effectively adding
154158
support for the array API standard is typically straightforward, and we ideally
155159
don't need to maintain any customization.
156160

157-
Three helper functions are available:
158-
159-
* ``array_namespace``: return the namespace based on input arrays and do some
160-
input validation (like refusing to work with masked arrays, please see the
161-
`RFC`_.)
162-
* ``_asarray``: a drop-in replacement for ``asarray`` with the additional
163-
parameters ``check_finite`` and ``order``. As stated above, try to limit
164-
the use of non-standard features. In the end we would want to upstream our
165-
needs to the compatibility library. Passing ``xp=xp`` avoids duplicate calls
166-
of ``array_namespace`` internally.
167-
* ``copy``: an alias for ``_asarray(x, copy=True)``.
168-
The ``copy`` parameter was only introduced to ``np.asarray`` in NumPy 2.0,
169-
so use of the helper is needed to support ``<2.0``. Passing ``xp=xp`` avoids
170-
duplicate calls of ``array_namespace`` internally.
161+
Various helper functions are available in ``scipy._lib._array_api`` - please see
162+
the ``__all__`` in that module for a list of current helpers, and their docstrings
163+
for more information.
171164

172165
To add support to a SciPy function which is defined in a ``.py`` file, what you
173166
have to change is:
@@ -183,11 +176,13 @@ Input array validation uses the following pattern::
183176
# alternatively, if there are multiple array inputs, include them all:
184177
xp = array_namespace(arr1, arr2)
185178

179+
# replace np.asarray with xp.asarray
180+
arr = xp.asarray(arr)
186181
# uses of non-standard parameters of np.asarray can be replaced with _asarray
187182
arr = _asarray(arr, order='C', dtype=xp.float64, xp=xp)
188183

189-
Note that if one input is a non-numpy array type, all array-like inputs have to
190-
be of that type; trying to mix non-numpy arrays with lists, Python scalars or
184+
Note that if one input is a non-NumPy array type, all array-like inputs have to
185+
be of that type; trying to mix non-NumPy arrays with lists, Python scalars or
191186
other arbitrary Python objects will raise an exception. For NumPy arrays, those
192187
types will continue to be accepted for backwards compatibility reasons.
193188

@@ -218,7 +213,7 @@ You would convert this like so::
218213
def toto(a, b):
219214
xp = array_namespace(a, b)
220215
a = xp.asarray(a)
221-
b = copy(b, xp=xp) # our custom helper is needed for copy
216+
b = xp_copy(b, xp=xp) # our custom helper is needed for copy
222217

223218
c = xp.sum(a) - xp.prod(b)
224219

@@ -231,7 +226,7 @@ You would convert this like so::
231226

232227
Going through compiled code requires going back to a NumPy array, because
233228
SciPy's extension modules only work with NumPy arrays (or memoryviews in the
234-
case of Cython), but not with other array types. For arrays on CPU, the
229+
case of Cython). For arrays on CPU, the
235230
conversions should be zero-copy, while on GPU and other devices the attempt at
236231
conversion will raise an exception. The reason for that is that silent data
237232
transfer between devices is considered bad practice, as it is likely to be a
@@ -245,13 +240,13 @@ The following pytest markers are available:
245240

246241
* ``array_api_compatible -> xp``: use a parametrisation to run a test on
247242
multiple array backends.
248-
* ``skip_xp_backends(*backends, reasons=None, np_only=False, cpu_only=False)``:
249-
skip certain backends and/or devices. ``np_only`` skips tests for all backends
250-
other than the default NumPy backend.
243+
* ``skip_xp_backends(*backends, reasons=None, np_only=False, cpu_only=False, exceptions=None)``:
244+
skip certain backends and/or devices.
251245
``@pytest.mark.usefixtures("skip_xp_backends")`` must be used alongside this
252-
marker for the skipping to apply.
246+
marker for the skipping to apply. See the fixture's docstring in ``scipy.conftest``
247+
for information on how use this marker to skip tests.
253248
* ``skip_xp_invalid_arg`` is used to skip tests that use arguments which
254-
are invalid when ``SCIPY_ARRAY_API`` is used. For instance, some tests of
249+
are invalid when ``SCIPY_ARRAY_API`` is enabled. For instance, some tests of
255250
`scipy.stats` functions pass masked arrays to the function being tested, but
256251
masked arrays are incompatible with the array API. Use of the
257252
``skip_xp_invalid_arg`` decorator allows these tests to protect against
@@ -263,41 +258,58 @@ The following pytest markers are available:
263258
default and only behavior, these tests (and the decorator itself) will be
264259
removed.
265260

266-
The following is an example using the markers::
261+
``scipy._lib._array_api`` contains array-agnostic assertions such as ``xp_assert_close``
262+
which can be used to replace assertions from `numpy.testing`.
263+
264+
The following examples demonstrate how to use the markers::
267265

268266
from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
267+
from scipy._lib._array_api import xp_assert_close
269268
...
270-
@pytest.mark.skip_xp_backends(np_only=True,
271-
reasons=['skip reason'])
269+
@pytest.mark.skip_xp_backends(np_only=True, reasons=['skip reason'])
272270
@pytest.mark.usefixtures("skip_xp_backends")
273271
@array_api_compatible
274272
def test_toto1(self, xp):
275273
a = xp.asarray([1, 2, 3])
276274
b = xp.asarray([0, 2, 5])
277-
toto(a, b)
275+
xp_assert_close(toto(a, b), a)
278276
...
279277
@pytest.mark.skip_xp_backends('array_api_strict', 'cupy',
280-
reasons=['skip reason 1',
281-
'skip reason 2',])
278+
reasons=['skip reason 1',
279+
'skip reason 2',],)
282280
@pytest.mark.usefixtures("skip_xp_backends")
283281
@array_api_compatible
284282
def test_toto2(self, xp):
285-
a = xp.asarray([1, 2, 3])
286-
b = xp.asarray([0, 2, 5])
287-
toto(a, b)
283+
...
288284
...
289285
# Do not run when SCIPY_ARRAY_API is used
290286
@skip_xp_invalid_arg
291287
def test_toto_masked_array(self):
292-
a = np.ma.asarray([1, 2, 3])
293-
b = np.ma.asarray([0, 2, 5])
294-
toto(a, b)
288+
...
295289

296290
Passing a custom reason to ``reasons`` when ``cpu_only=True`` is unsupported
297291
since ``cpu_only=True`` can be used alongside passing ``backends``. Also,
298292
the reason for using ``cpu_only`` is likely just that compiled code is used
299293
in the function(s) being tested.
300294

295+
Passing names of backends into ``exceptions`` means that they will not be skipped
296+
by ``cpu_only=True``. This is useful when delegation is implemented for some,
297+
but not all, non-CPU backends, and the CPU code path requires conversion to NumPy
298+
for compiled code::
299+
300+
# array-api-strict and CuPy will always be skipped, for the given reasons.
301+
# All libraries using a non-CPU device will also be skipped, apart from
302+
# JAX, for which delegation is implemented (hence non-CPU execution is supported).
303+
@pytest.mark.skip_xp_backends('array_api_strict', 'cupy',
304+
reasons=['skip reason 1',
305+
'skip reason 2',],
306+
cpu_only=True,
307+
exceptions=['jax.numpy'],)
308+
@pytest.mark.usefixtures("skip_xp_backends")
309+
@array_api_compatible
310+
def test_toto(self, xp):
311+
...
312+
301313
When every test function in a file has been updated for array API
302314
compatibility, one can reduce verbosity by telling ``pytest`` to apply the
303315
markers to every test function using ``pytestmark``::
@@ -309,9 +321,7 @@ markers to every test function using ``pytestmark``::
309321
...
310322
@skip_xp_backends(np_only=True, reasons=['skip reason'])
311323
def test_toto1(self, xp):
312-
a = xp.asarray([1, 2, 3])
313-
b = xp.asarray([0, 2, 5])
314-
toto(a, b)
324+
...
315325

316326
After applying these markers, ``dev.py test`` can be used with the new option
317327
``-b`` or ``--array-api-backend``::
@@ -321,12 +331,12 @@ After applying these markers, ``dev.py test`` can be used with the new option
321331
This automatically sets ``SCIPY_ARRAY_API`` appropriately. To test a library
322332
that has multiple devices with a non-default device, a second environment
323333
variable (``SCIPY_DEVICE``, only used in the test suite) can be set. Valid
324-
values depend on the array library under test, e.g. for PyTorch (currently the
325-
only library with multi-device support that is known to work) valid values are
326-
``"cpu", "cuda", "mps"``. So to run the test suite with the PyTorch MPS
334+
values depend on the array library under test, e.g. for PyTorch, valid values are
335+
``"cpu", "cuda", "mps"``. To run the test suite with the PyTorch MPS
327336
backend, use: ``SCIPY_DEVICE=mps python dev.py test -b pytorch``.
328337

329-
Note that there is a GitHub Actions workflow which runs ``pytorch-cpu``.
338+
Note that there is a GitHub Actions workflow which tests with array-api-strict,
339+
PyTorch, and JAX on CPU.
330340

331341

332342
Additional information
@@ -346,3 +356,4 @@ helped during the development phase:
346356
`#25956 <https://github.com/scikit-learn/scikit-learn/pull/25956>`__
347357

348358
.. _RFC: https://github.com/scipy/scipy/issues/18286
359+
.. _the tracker issue: https://github.com/scipy/scipy/issues/18867

scipy/_lib/_array_api.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,21 @@
2020
from scipy._lib import array_api_compat
2121
from scipy._lib.array_api_compat import (
2222
is_array_api_obj,
23-
size,
23+
size as xp_size,
2424
numpy as np_compat,
25-
device
25+
device as xp_device
2626
)
2727

28-
__all__ = ['array_namespace', '_asarray', 'size', 'device']
28+
__all__ = [
29+
'_asarray', 'array_namespace', 'assert_almost_equal', 'assert_array_almost_equal',
30+
'get_xp_devices',
31+
'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy', 'is_torch',
32+
'SCIPY_ARRAY_API', 'SCIPY_DEVICE', 'scipy_namespace_for',
33+
'xp_assert_close', 'xp_assert_equal', 'xp_assert_less',
34+
'xp_atleast_nd', 'xp_copy', 'xp_copysign', 'xp_device',
35+
'xp_moveaxis_to_end', 'xp_ravel', 'xp_real', 'xp_sign', 'xp_size',
36+
'xp_take_along_axis', 'xp_unsupported_param_msg', 'xp_vector_norm',
37+
]
2938

3039

3140
# To enable array API and strict array-like input validation
@@ -44,7 +53,7 @@
4453
ArrayLike = Array | npt.ArrayLike
4554

4655

47-
def compliance_scipy(arrays: list[ArrayLike]) -> list[Array]:
56+
def _compliance_scipy(arrays: list[ArrayLike]) -> list[Array]:
4857
"""Raise exceptions on known-bad subclasses.
4958
5059
The following subclasses are not supported and raise and error:
@@ -111,7 +120,7 @@ def array_namespace(*arrays: Array) -> ModuleType:
111120
112121
1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed
113122
dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``.
114-
2. `compliance_scipy` raise exceptions on known-bad subclasses. See
123+
2. `_compliance_scipy` raise exceptions on known-bad subclasses. See
115124
its definition for more details.
116125
117126
When the global switch is False, it defaults to the `numpy` namespace.
@@ -124,7 +133,7 @@ def array_namespace(*arrays: Array) -> ModuleType:
124133

125134
_arrays = [array for array in arrays if array is not None]
126135

127-
_arrays = compliance_scipy(_arrays)
136+
_arrays = _compliance_scipy(_arrays)
128137

129138
return array_api_compat.array_namespace(*_arrays)
130139

@@ -176,18 +185,18 @@ def _asarray(
176185
return array
177186

178187

179-
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
188+
def xp_atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
180189
"""Recursively expand the dimension to have at least `ndim`."""
181190
if xp is None:
182191
xp = array_namespace(x)
183192
x = xp.asarray(x)
184193
if x.ndim < ndim:
185194
x = xp.expand_dims(x, axis=0)
186-
x = atleast_nd(x, ndim=ndim, xp=xp)
195+
x = xp_atleast_nd(x, ndim=ndim, xp=xp)
187196
return x
188197

189198

190-
def copy(x: Array, *, xp: ModuleType | None = None) -> Array:
199+
def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
191200
"""
192201
Copies an array.
193202
@@ -207,7 +216,8 @@ def copy(x: Array, *, xp: ModuleType | None = None) -> Array:
207216
This copy function does not offer all the semantics of `np.copy`, i.e. the
208217
`subok` and `order` keywords are not used.
209218
"""
210-
# Note: xp.asarray fails if xp is numpy.
219+
# Note: for older NumPy versions, `np.asarray` did not support the `copy` kwarg,
220+
# so this uses our other helper `_asarray`.
211221
if xp is None:
212222
xp = array_namespace(x)
213223

@@ -395,14 +405,14 @@ def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
395405
*args, **kwds)
396406

397407

398-
def cov(x: Array, *, xp: ModuleType | None = None) -> Array:
408+
def xp_cov(x: Array, *, xp: ModuleType | None = None) -> Array:
399409
if xp is None:
400410
xp = array_namespace(x)
401411

402-
X = copy(x, xp=xp)
412+
X = xp_copy(x, xp=xp)
403413
dtype = xp.result_type(X, xp.float64)
404414

405-
X = atleast_nd(X, ndim=2, xp=xp)
415+
X = xp_atleast_nd(X, ndim=2, xp=xp)
406416
X = xp.asarray(X, dtype=dtype)
407417

408418
avg = xp.mean(X, axis=1)

scipy/_lib/_elementwise_iterative_method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import math
1515
import numpy as np
1616
from ._util import _RichResult, _call_callback_maybe_halt
17-
from ._array_api import array_namespace, size as xp_size
17+
from ._array_api import array_namespace, xp_size
1818

1919
_ESIGNERR = -1
2020
_ECONVERR = -2

scipy/_lib/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616

1717
import numpy as np
18-
from scipy._lib._array_api import array_namespace, is_numpy, size as xp_size
18+
from scipy._lib._array_api import array_namespace, is_numpy, xp_size
1919

2020

2121
AxisError: type[Exception]

scipy/_lib/tests/test__util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
1414

1515
from scipy._lib._array_api import (xp_assert_equal, xp_assert_close, is_numpy,
16-
copy as xp_copy, is_array_api_strict)
16+
xp_copy, is_array_api_strict)
1717
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
1818
getfullargspec_no_self, FullArgSpec,
1919
rng_integers, _validate_int, _rename_parameter,

0 commit comments

Comments
 (0)