Skip to content

Commit 496a146

Browse files
authored
MAINT: _lib: vendor and use array-api-extra (scipy#21600)
1 parent 1c9b1aa commit 496a146

File tree

13 files changed

+58
-72
lines changed

13 files changed

+58
-72
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,6 @@
2626
[submodule "scipy/_lib/cobyqa"]
2727
path = scipy/_lib/cobyqa
2828
url = https://github.com/cobyqa/cobyqa.git
29+
[submodule "scipy/_lib/array_api_extra"]
30+
path = scipy/_lib/array_api_extra
31+
url = https://github.com/lucascolley/array-api-extra.git

scipy/_lib/_array_api.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from __future__ import annotations
1010

1111
import os
12-
import warnings
1312

1413
from types import ModuleType
1514
from typing import Any, Literal, TYPE_CHECKING
@@ -36,7 +35,7 @@
3635
'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy', 'is_torch',
3736
'SCIPY_ARRAY_API', 'SCIPY_DEVICE', 'scipy_namespace_for',
3837
'xp_assert_close', 'xp_assert_equal', 'xp_assert_less',
39-
'xp_atleast_nd', 'xp_copy', 'xp_copysign', 'xp_device',
38+
'xp_copy', 'xp_copysign', 'xp_device',
4039
'xp_moveaxis_to_end', 'xp_ravel', 'xp_real', 'xp_sign', 'xp_size',
4140
'xp_take_along_axis', 'xp_unsupported_param_msg', 'xp_vector_norm',
4241
'xp_create_diagonal'
@@ -200,17 +199,6 @@ def _asarray(
200199
return array
201200

202201

203-
def xp_atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
204-
"""Recursively expand the dimension to have at least `ndim`."""
205-
if xp is None:
206-
xp = array_namespace(x)
207-
x = xp.asarray(x)
208-
if x.ndim < ndim:
209-
x = xp.expand_dims(x, axis=0)
210-
x = xp_atleast_nd(x, ndim=ndim, xp=xp)
211-
return x
212-
213-
214202
def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
215203
"""
216204
Copies an array.
@@ -382,34 +370,6 @@ def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
382370
*args, **kwds)
383371

384372

385-
def xp_cov(x: Array, *, xp: ModuleType | None = None) -> Array:
386-
if xp is None:
387-
xp = array_namespace(x)
388-
389-
X = xp_copy(x, xp=xp)
390-
dtype = xp.result_type(X, xp.float64)
391-
392-
X = xp_atleast_nd(X, ndim=2, xp=xp)
393-
X = xp.asarray(X, dtype=dtype)
394-
395-
avg = xp.mean(X, axis=1)
396-
fact = X.shape[1] - 1
397-
398-
if fact <= 0:
399-
warnings.warn("Degrees of freedom <= 0 for slice",
400-
RuntimeWarning, stacklevel=2)
401-
fact = 0.0
402-
403-
X -= avg[:, None]
404-
X_T = X.T
405-
if xp.isdtype(X_T.dtype, 'complex floating'):
406-
X_T = xp.conj(X_T)
407-
c = X @ X_T
408-
c /= fact
409-
axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
410-
return xp.squeeze(c, axis=axes)
411-
412-
413373
def xp_unsupported_param_msg(param: Any) -> str:
414374
return f'Providing {param!r} is only supported for numpy arrays.'
415375

scipy/_lib/array_api_extra

Submodule array_api_extra added at b83ba61

scipy/_lib/meson.build

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ endif
1111
if not fs.exists('array_api_compat/README.md')
1212
error('Missing the `array_api_compat` submodule! Run `git submodule update --init` to fix this.')
1313
endif
14+
if not fs.exists('array_api_extra/README.md')
15+
error('Missing the `array_api_extra` submodule! Run `git submodule update --init` to fix this.')
16+
endif
1417
if not fs.exists('pocketfft/README.md')
1518
error('Missing the `pocketfft` submodule! Run `git submodule update --init` to fix this.')
1619
endif
@@ -213,6 +216,18 @@ py3.install_sources(
213216
subdir: 'scipy/_lib/array_api_compat/torch',
214217
)
215218

219+
# `array_api_extra` install to simplify import path;
220+
# should be updated whenever new files are added to `array_api_extra`
221+
222+
py3.install_sources(
223+
[
224+
'array_api_extra/src/array_api_extra/__init__.py',
225+
'array_api_extra/src/array_api_extra/_funcs.py',
226+
'array_api_extra/src/array_api_extra/_typing.py',
227+
],
228+
subdir: 'scipy/_lib/array_api_extra',
229+
)
230+
216231
py3.install_sources(
217232
[
218233
'cobyqa/cobyqa/__init__.py',

scipy/cluster/tests/test_vq.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
from scipy.conftest import array_api_compatible
1616
from scipy.sparse._sputils import matrix
1717

18+
from scipy._lib import array_api_extra as xpx
1819
from scipy._lib._array_api import (
19-
SCIPY_ARRAY_API, xp_copy, xp_cov, xp_assert_close, xp_assert_equal
20+
SCIPY_ARRAY_API, array_namespace, xp_copy, xp_assert_close, xp_assert_equal
2021
)
2122

2223
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
@@ -352,11 +353,12 @@ def test_krandinit(self, xp):
352353
datas = [xp.reshape(data, (200, 2)),
353354
xp.reshape(data, (20, 20))[:10, :]]
354355
k = int(1e6)
356+
xp_test = array_namespace(data)
355357
for data in datas:
356358
rng = np.random.default_rng(1234)
357-
init = _krandinit(data, k, rng, xp)
358-
orig_cov = xp_cov(data.T)
359-
init_cov = xp_cov(init.T)
359+
init = _krandinit(data, k, rng, xp_test)
360+
orig_cov = xpx.cov(data.T, xp=xp_test)
361+
init_cov = xpx.cov(init.T, xp=xp_test)
360362
xp_assert_close(orig_cov, init_cov, atol=1.1e-2)
361363

362364
def test_kmeans2_empty(self, xp):

scipy/cluster/vq.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,11 @@
6868
import numpy as np
6969
from collections import deque
7070
from scipy._lib._array_api import (
71-
_asarray, array_namespace, xp_size, xp_atleast_nd, xp_copy, xp_cov
71+
_asarray, array_namespace, xp_size, xp_copy
7272
)
7373
from scipy._lib._util import (check_random_state, rng_integers,
7474
_transition_to_rng)
75+
from scipy._lib import array_api_extra as xpx
7576
from scipy.spatial.distance import cdist
7677

7778
from . import _vq
@@ -548,7 +549,7 @@ def _krandinit(data, k, rng, xp):
548549
k = np.asarray(k)
549550

550551
if data.ndim == 1:
551-
_cov = xp_cov(data)
552+
_cov = xpx.cov(data, xp=xp)
552553
x = rng.standard_normal(size=k)
553554
x = xp.asarray(x)
554555
x *= xp.sqrt(_cov)
@@ -560,7 +561,7 @@ def _krandinit(data, k, rng, xp):
560561
sVh = s[:, None] * vh / xp.sqrt(data.shape[0] - xp.asarray(1.))
561562
x = x @ sVh
562563
else:
563-
_cov = xp_atleast_nd(xp_cov(data.T), ndim=2)
564+
_cov = xpx.atleast_nd(xpx.cov(data.T, xp=xp), ndim=2, xp=xp)
564565

565566
# k rows, d cols (one row = one obs)
566567
# Generate k sample of a random variable ~ Gaussian(mu, cov)

scipy/optimize/_differentiable_functions.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from ._numdiff import approx_derivative, group_columns
44
from ._hessian_update_strategy import HessianUpdateStrategy
55
from scipy.sparse.linalg import LinearOperator
6-
from scipy._lib._array_api import xp_atleast_nd, array_namespace
6+
from scipy._lib._array_api import array_namespace
7+
from scipy._lib import array_api_extra as xpx
78

89

910
FD_METHODS = ('2-point', '3-point', 'cs')
@@ -183,7 +184,7 @@ def __init__(self, fun, x0, args, grad, hess, finite_diff_rel_step,
183184
"quasi-Newton strategies.")
184185

185186
self.xp = xp = array_namespace(x0)
186-
_x = xp_atleast_nd(x0, ndim=1, xp=xp)
187+
_x = xpx.atleast_nd(xp.asarray(x0), ndim=1, xp=xp)
187188
_dtype = xp.float64
188189
if xp.isdtype(_x.dtype, "real floating"):
189190
_dtype = _x.dtype
@@ -274,7 +275,7 @@ def _update_x(self, x):
274275
# ensure that self.x is a copy of x. Don't store a reference
275276
# otherwise the memoization doesn't work properly.
276277

277-
_x = xp_atleast_nd(x, ndim=1, xp=self.xp)
278+
_x = xpx.atleast_nd(self.xp.asarray(x), ndim=1, xp=self.xp)
278279
self.x = self.xp.astype(_x, self.x_dtype)
279280
self.f_updated = False
280281
self.g_updated = False
@@ -283,7 +284,7 @@ def _update_x(self, x):
283284
else:
284285
# ensure that self.x is a copy of x. Don't store a reference
285286
# otherwise the memoization doesn't work properly.
286-
_x = xp_atleast_nd(x, ndim=1, xp=self.xp)
287+
_x = xpx.atleast_nd(self.xp.asarray(x), ndim=1, xp=self.xp)
287288
self.x = self.xp.astype(_x, self.x_dtype)
288289
self.f_updated = False
289290
self.g_updated = False
@@ -380,7 +381,7 @@ def __init__(self, fun, x0, jac, hess,
380381
"strategies.")
381382

382383
self.xp = xp = array_namespace(x0)
383-
_x = xp_atleast_nd(x0, ndim=1, xp=xp)
384+
_x = xpx.atleast_nd(xp.asarray(x0), ndim=1, xp=xp)
384385
_dtype = xp.float64
385386
if xp.isdtype(_x.dtype, "real floating"):
386387
_dtype = _x.dtype
@@ -557,15 +558,15 @@ def update_x(x):
557558
self._update_jac()
558559
self.x_prev = self.x
559560
self.J_prev = self.J
560-
_x = xp_atleast_nd(x, ndim=1, xp=self.xp)
561+
_x = xpx.atleast_nd(self.xp.asarray(x), ndim=1, xp=self.xp)
561562
self.x = self.xp.astype(_x, self.x_dtype)
562563
self.f_updated = False
563564
self.J_updated = False
564565
self.H_updated = False
565566
self._update_hess()
566567
else:
567568
def update_x(x):
568-
_x = xp_atleast_nd(x, ndim=1, xp=self.xp)
569+
_x = xpx.atleast_nd(self.xp.asarray(x), ndim=1, xp=self.xp)
569570
self.x = self.xp.astype(_x, self.x_dtype)
570571
self.f_updated = False
571572
self.J_updated = False
@@ -637,7 +638,7 @@ def __init__(self, A, x0, sparse_jacobian):
637638
self.m, self.n = self.J.shape
638639

639640
self.xp = xp = array_namespace(x0)
640-
_x = xp_atleast_nd(x0, ndim=1, xp=xp)
641+
_x = xpx.atleast_nd(xp.asarray(x0), ndim=1, xp=xp)
641642
_dtype = xp.float64
642643
if xp.isdtype(_x.dtype, "real floating"):
643644
_dtype = _x.dtype
@@ -654,7 +655,7 @@ def __init__(self, A, x0, sparse_jacobian):
654655

655656
def _update_x(self, x):
656657
if not np.array_equal(x, self.x):
657-
_x = xp_atleast_nd(x, ndim=1, xp=self.xp)
658+
_x = xpx.atleast_nd(self.xp.asarray(x), ndim=1, xp=self.xp)
658659
self.x = self.xp.astype(_x, self.x_dtype)
659660
self.f_updated = False
660661

scipy/optimize/_numdiff.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from scipy.sparse.linalg import LinearOperator
77
from ..sparse import issparse, csc_matrix, csr_matrix, coo_matrix, find
88
from ._group_columns import group_dense, group_sparse
9-
from scipy._lib._array_api import xp_atleast_nd, array_namespace
9+
from scipy._lib._array_api import array_namespace
10+
from scipy._lib import array_api_extra as xpx
1011

1112

1213
def _adjust_scheme_to_bounds(x0, h, num_steps, scheme, lb, ub):
@@ -440,7 +441,7 @@ def approx_derivative(fun, x0, method='3-point', rel_step=None, abs_step=None,
440441
raise ValueError(f"Unknown method '{method}'. ")
441442

442443
xp = array_namespace(x0)
443-
_x = xp_atleast_nd(x0, ndim=1, xp=xp)
444+
_x = xpx.atleast_nd(xp.asarray(x0), ndim=1, xp=xp)
444445
_dtype = xp.float64
445446
if xp.isdtype(_x.dtype, "real floating"):
446447
_dtype = _x.dtype

scipy/optimize/_optimize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
from scipy._lib._util import (MapWrapper, check_random_state, _RichResult,
4242
_call_callback_maybe_halt, _transition_to_rng)
4343
from scipy.optimize._differentiable_functions import ScalarFunction, FD_METHODS
44-
from scipy._lib._array_api import (array_namespace, xp_atleast_nd,
45-
xp_create_diagonal)
44+
from scipy._lib._array_api import array_namespace, xp_create_diagonal
45+
from scipy._lib import array_api_extra as xpx
4646

4747

4848
# standard status messages of optimizers
@@ -442,7 +442,7 @@ def rosen_hess(x):
442442
443443
"""
444444
xp = array_namespace(x)
445-
x = xp_atleast_nd(x, ndim=1, xp=xp)
445+
x = xpx.atleast_nd(x, ndim=1, xp=xp)
446446
if xp.isdtype(x.dtype, 'integral'):
447447
x = xp.astype(x, xp.asarray(1.).dtype)
448448
H = (xp_create_diagonal(-400 * x[:-1], offset=1, xp=xp)
@@ -486,7 +486,7 @@ def rosen_hess_prod(x, p):
486486
487487
"""
488488
xp = array_namespace(x, p)
489-
x = xp_atleast_nd(x, ndim=1, xp=xp)
489+
x = xpx.atleast_nd(x, ndim=1, xp=xp)
490490
if xp.isdtype(x.dtype, 'integral'):
491491
x = xp.astype(x, xp.asarray(1.).dtype)
492492
p = xp.asarray(p, dtype=x.dtype)

scipy/optimize/_slsqp_py.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
_check_clip_x)
2525
from ._numdiff import approx_derivative
2626
from ._constraints import old_bound_to_new, _arr_to_scalar
27-
from scipy._lib._array_api import xp_atleast_nd, array_namespace
27+
from scipy._lib._array_api import array_namespace
28+
from scipy._lib import array_api_extra as xpx
2829

2930

3031
__docformat__ = "restructuredtext en"
@@ -250,7 +251,7 @@ def _minimize_slsqp(func, x0, args=(), jac=None, bounds=None,
250251

251252
# Transform x0 into an array.
252253
xp = array_namespace(x0)
253-
x0 = xp_atleast_nd(x0, ndim=1, xp=xp)
254+
x0 = xpx.atleast_nd(xp.asarray(x0), ndim=1, xp=xp)
254255
dtype = xp.float64
255256
if xp.isdtype(x0.dtype, "real floating"):
256257
dtype = x0.dtype

0 commit comments

Comments
 (0)