Skip to content

Commit 37174ad

Browse files
committed
MAINT: _lib._array_api: clean-up
1 parent 94532e7 commit 37174ad

23 files changed

+88
-79
lines changed

scipy/_lib/_array_api.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,21 @@
2020
from scipy._lib import array_api_compat
2121
from scipy._lib.array_api_compat import (
2222
is_array_api_obj,
23-
size,
23+
size as xp_size,
2424
numpy as np_compat,
25-
device
25+
device as xp_device
2626
)
2727

28-
__all__ = ['array_namespace', '_asarray', 'size', 'device']
28+
__all__ = [
29+
'_asarray', 'array_namespace', 'assert_almost_equal', 'assert_array_almost_equal',
30+
'get_xp_devices',
31+
'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy', 'is_torch',
32+
'SCIPY_ARRAY_API', 'SCIPY_DEVICE', 'scipy_namespace_for',
33+
'xp_assert_close', 'xp_assert_equal', 'xp_assert_less',
34+
'xp_atleast_nd', 'xp_copy', 'xp_copysign', 'xp_device',
35+
'xp_moveaxis_to_end', 'xp_ravel', 'xp_real', 'xp_sign', 'xp_size',
36+
'xp_take_along_axis', 'xp_unsupported_param_msg', 'xp_vector_norm',
37+
]
2938

3039

3140
# To enable array API and strict array-like input validation
@@ -44,7 +53,7 @@
4453
ArrayLike = Array | npt.ArrayLike
4554

4655

47-
def compliance_scipy(arrays: list[ArrayLike]) -> list[Array]:
56+
def _compliance_scipy(arrays: list[ArrayLike]) -> list[Array]:
4857
"""Raise exceptions on known-bad subclasses.
4958
5059
The following subclasses are not supported and raise and error:
@@ -111,7 +120,7 @@ def array_namespace(*arrays: Array) -> ModuleType:
111120
112121
1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed
113122
dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``.
114-
2. `compliance_scipy` raise exceptions on known-bad subclasses. See
123+
2. `_compliance_scipy` raise exceptions on known-bad subclasses. See
115124
its definition for more details.
116125
117126
When the global switch is False, it defaults to the `numpy` namespace.
@@ -124,7 +133,7 @@ def array_namespace(*arrays: Array) -> ModuleType:
124133

125134
_arrays = [array for array in arrays if array is not None]
126135

127-
_arrays = compliance_scipy(_arrays)
136+
_arrays = _compliance_scipy(_arrays)
128137

129138
return array_api_compat.array_namespace(*_arrays)
130139

@@ -176,18 +185,18 @@ def _asarray(
176185
return array
177186

178187

179-
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
188+
def xp_atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
180189
"""Recursively expand the dimension to have at least `ndim`."""
181190
if xp is None:
182191
xp = array_namespace(x)
183192
x = xp.asarray(x)
184193
if x.ndim < ndim:
185194
x = xp.expand_dims(x, axis=0)
186-
x = atleast_nd(x, ndim=ndim, xp=xp)
195+
x = xp_atleast_nd(x, ndim=ndim, xp=xp)
187196
return x
188197

189198

190-
def copy(x: Array, *, xp: ModuleType | None = None) -> Array:
199+
def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
191200
"""
192201
Copies an array.
193202
@@ -207,7 +216,8 @@ def copy(x: Array, *, xp: ModuleType | None = None) -> Array:
207216
This copy function does not offer all the semantics of `np.copy`, i.e. the
208217
`subok` and `order` keywords are not used.
209218
"""
210-
# Note: xp.asarray fails if xp is numpy.
219+
# Note: for older NumPy versions, `np.asarray` did not support the `copy` kwarg,
220+
# so this uses our other helper `_asarray`.
211221
if xp is None:
212222
xp = array_namespace(x)
213223

@@ -395,14 +405,14 @@ def assert_almost_equal(actual, desired, decimal=7, *args, **kwds):
395405
*args, **kwds)
396406

397407

398-
def cov(x: Array, *, xp: ModuleType | None = None) -> Array:
408+
def xp_cov(x: Array, *, xp: ModuleType | None = None) -> Array:
399409
if xp is None:
400410
xp = array_namespace(x)
401411

402-
X = copy(x, xp=xp)
412+
X = xp_copy(x, xp=xp)
403413
dtype = xp.result_type(X, xp.float64)
404414

405-
X = atleast_nd(X, ndim=2, xp=xp)
415+
X = xp_atleast_nd(X, ndim=2, xp=xp)
406416
X = xp.asarray(X, dtype=dtype)
407417

408418
avg = xp.mean(X, axis=1)

scipy/_lib/_elementwise_iterative_method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import math
1515
import numpy as np
1616
from ._util import _RichResult, _call_callback_maybe_halt
17-
from ._array_api import array_namespace, size as xp_size
17+
from ._array_api import array_namespace, xp_size
1818

1919
_ESIGNERR = -1
2020
_ECONVERR = -2

scipy/_lib/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616

1717
import numpy as np
18-
from scipy._lib._array_api import array_namespace, is_numpy, size as xp_size
18+
from scipy._lib._array_api import array_namespace, is_numpy, xp_size
1919

2020

2121
AxisError: type[Exception]

scipy/_lib/tests/test__util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
1414

1515
from scipy._lib._array_api import (xp_assert_equal, xp_assert_close, is_numpy,
16-
copy as xp_copy, is_array_api_strict)
16+
xp_copy, is_array_api_strict)
1717
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
1818
getfullargspec_no_self, FullArgSpec,
1919
rng_integers, _validate_int, _rename_parameter,

scipy/_lib/tests/test_array_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from scipy.conftest import array_api_compatible
55
from scipy._lib._array_api import (
6-
_GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy
6+
_GLOBAL_CONFIG, array_namespace, _asarray, xp_copy, xp_assert_equal, is_numpy
77
)
88
import scipy._lib.array_api_compat.numpy as np_compat
99

@@ -60,7 +60,7 @@ def test_array_likes(self):
6060
def test_copy(self, xp):
6161
for _xp in [xp, None]:
6262
x = xp.asarray([1, 2, 3])
63-
y = copy(x, xp=_xp)
63+
y = xp_copy(x, xp=_xp)
6464
# with numpy we'd want to use np.shared_memory, but that's not specified
6565
# in the array-api
6666
x[0] = 10

scipy/cluster/hierarchy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@
134134
import numpy as np
135135
from . import _hierarchy, _optimal_leaf_ordering
136136
import scipy.spatial.distance as distance
137-
from scipy._lib._array_api import array_namespace, _asarray, copy, is_jax
137+
from scipy._lib._array_api import array_namespace, _asarray, xp_copy, is_jax
138138
from scipy._lib._disjoint_set import DisjointSet
139139

140140

@@ -1358,7 +1358,7 @@ def cut_tree(Z, n_clusters=None, height=None):
13581358

13591359
for i, node in enumerate(nodes):
13601360
idx = node.pre_order()
1361-
this_group = copy(last_group, xp=xp)
1361+
this_group = xp_copy(last_group, xp=xp)
13621362
# TODO ARRAY_API complex indexing not supported
13631363
this_group[idx] = xp.min(last_group[idx])
13641364
this_group[this_group > xp.max(last_group[idx])] -= 1
@@ -1822,14 +1822,14 @@ def from_mlab_linkage(Z):
18221822

18231823
# If it's empty, return it.
18241824
if len(Zs) == 0 or (len(Zs) == 1 and Zs[0] == 0):
1825-
return copy(Z, xp=xp)
1825+
return xp_copy(Z, xp=xp)
18261826

18271827
if len(Zs) != 2:
18281828
raise ValueError("The linkage array must be rectangular.")
18291829

18301830
# If it contains no rows, return it.
18311831
if Zs[0] == 0:
1832-
return copy(Z, xp=xp)
1832+
return xp_copy(Z, xp=xp)
18331833

18341834
if xp.min(Z[:, 0:2]) != 1.0 and xp.max(Z[:, 0:2]) != 2 * Zs[0]:
18351835
raise ValueError('The format of the indices is not 1..N')
@@ -1925,7 +1925,7 @@ def to_mlab_linkage(Z):
19251925
Z = _asarray(Z, order='C', dtype=xp.float64, xp=xp)
19261926
Zs = Z.shape
19271927
if len(Zs) == 0 or (len(Zs) == 1 and Zs[0] == 0):
1928-
return copy(Z, xp=xp)
1928+
return xp_copy(Z, xp=xp)
19291929
is_valid_linkage(Z, throw=True, name='Z')
19301930

19311931
return xp.concat((Z[:, :2] + 1.0, Z[:, 2:3]), axis=1)

scipy/cluster/tests/test_vq.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from scipy.sparse._sputils import matrix
1717

1818
from scipy._lib._array_api import (
19-
SCIPY_ARRAY_API, copy, cov, xp_assert_close, xp_assert_equal
19+
SCIPY_ARRAY_API, xp_copy, xp_cov, xp_assert_close, xp_assert_equal
2020
)
2121

2222
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
@@ -307,7 +307,7 @@ def test_kmeans2_rank1(self, xp):
307307
data1 = data[:, 0]
308308

309309
initc = data1[:3]
310-
code = copy(initc, xp=xp)
310+
code = xp_copy(initc, xp=xp)
311311
kmeans2(data1, code, iter=1)[0]
312312
kmeans2(data1, code, iter=2)[0]
313313

@@ -353,8 +353,8 @@ def test_krandinit(self, xp):
353353
for data in datas:
354354
rng = np.random.default_rng(1234)
355355
init = _krandinit(data, k, rng, xp)
356-
orig_cov = cov(data.T)
357-
init_cov = cov(init.T)
356+
orig_cov = xp_cov(data.T)
357+
init_cov = xp_cov(init.T)
358358
xp_assert_close(orig_cov, init_cov, atol=1.1e-2)
359359

360360
def test_kmeans2_empty(self, xp):

scipy/cluster/vq.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
import numpy as np
6969
from collections import deque
7070
from scipy._lib._array_api import (
71-
_asarray, array_namespace, size, atleast_nd, copy, cov
71+
_asarray, array_namespace, xp_size, xp_atleast_nd, xp_copy, xp_cov
7272
)
7373
from scipy._lib._util import check_random_state, rng_integers
7474
from scipy.spatial.distance import cdist
@@ -472,8 +472,8 @@ def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True,
472472
raise ValueError(f"iter must be at least 1, got {iter}")
473473

474474
# Determine whether a count (scalar) or an initial guess (array) was passed.
475-
if size(guess) != 1:
476-
if size(guess) < 1:
475+
if xp_size(guess) != 1:
476+
if xp_size(guess) < 1:
477477
raise ValueError(f"Asked for 0 clusters. Initial book was {guess}")
478478
return _kmeans(obs, guess, thresh=thresh, xp=xp)
479479

@@ -551,23 +551,23 @@ def _krandinit(data, k, rng, xp):
551551
k = np.asarray(k)
552552

553553
if data.ndim == 1:
554-
_cov = cov(data)
554+
_cov = xp_cov(data)
555555
x = rng.standard_normal(size=k)
556556
x = xp.asarray(x)
557557
x *= xp.sqrt(_cov)
558558
elif data.shape[1] > data.shape[0]:
559559
# initialize when the covariance matrix is rank deficient
560560
_, s, vh = xp.linalg.svd(data - mu, full_matrices=False)
561-
x = rng.standard_normal(size=(k, size(s)))
561+
x = rng.standard_normal(size=(k, xp_size(s)))
562562
x = xp.asarray(x)
563563
sVh = s[:, None] * vh / xp.sqrt(data.shape[0] - xp.asarray(1.))
564564
x = x @ sVh
565565
else:
566-
_cov = atleast_nd(cov(data.T), ndim=2)
566+
_cov = xp_atleast_nd(xp_cov(data.T), ndim=2)
567567

568568
# k rows, d cols (one row = one obs)
569569
# Generate k sample of a random variable ~ Gaussian(mu, cov)
570-
x = rng.standard_normal(size=(k, size(mu)))
570+
x = rng.standard_normal(size=(k, xp_size(mu)))
571571
x = xp.asarray(x)
572572
x = x @ xp.linalg.cholesky(_cov).T
573573

@@ -782,19 +782,19 @@ def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
782782
else:
783783
xp = array_namespace(data, k)
784784
data = _asarray(data, xp=xp, check_finite=check_finite)
785-
code_book = copy(k, xp=xp)
785+
code_book = xp_copy(k, xp=xp)
786786
if data.ndim == 1:
787787
d = 1
788788
elif data.ndim == 2:
789789
d = data.shape[1]
790790
else:
791791
raise ValueError("Input of rank > 2 is not supported.")
792792

793-
if size(data) < 1 or size(code_book) < 1:
793+
if xp_size(data) < 1 or xp_size(code_book) < 1:
794794
raise ValueError("Empty input is not supported.")
795795

796796
# If k is not a single value, it should be compatible with data's shape
797-
if minit == 'matrix' or size(code_book) > 1:
797+
if minit == 'matrix' or xp_size(code_book) > 1:
798798
if data.ndim != code_book.ndim:
799799
raise ValueError("k array doesn't match data rank")
800800
nc = code_book.shape[0]

scipy/fft/_fftlog_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ._basic import rfft, irfft
44
from ..special import loggamma, poch
55

6-
from scipy._lib._array_api import array_namespace, copy
6+
from scipy._lib._array_api import array_namespace
77

88
__all__ = ['fht', 'ifht', 'fhtoffset']
99

@@ -106,12 +106,12 @@ def fhtcoeff(n, dln, mu, offset=0.0, bias=0.0, inverse=False):
106106
if np.isinf(u[0]) and not inverse:
107107
warn('singular transform; consider changing the bias', stacklevel=3)
108108
# fix coefficient to obtain (potentially correct) transform anyway
109-
u = copy(u)
109+
u = np.copy(u)
110110
u[0] = 0
111111
elif u[0] == 0 and inverse:
112112
warn('singular inverse transform; consider changing the bias', stacklevel=3)
113113
# fix coefficient to obtain (potentially correct) inverse anyway
114-
u = copy(u)
114+
u = np.copy(u)
115115
u[0] = np.inf
116116

117117
return u

scipy/fft/tests/test_basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import scipy.fft as fft
1010
from scipy.conftest import array_api_compatible
1111
from scipy._lib._array_api import (
12-
array_namespace, size, xp_assert_close, xp_assert_equal
12+
array_namespace, xp_size, xp_assert_close, xp_assert_equal
1313
)
1414

1515
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
@@ -123,7 +123,7 @@ def test_ifftn(self, xp):
123123

124124
def test_rfft(self, xp):
125125
x = xp.asarray(random(29), dtype=xp.float64)
126-
for n in [size(x), 2*size(x)]:
126+
for n in [xp_size(x), 2*xp_size(x)]:
127127
for norm in [None, "backward", "ortho", "forward"]:
128128
xp_assert_close(fft.rfft(x, n=n, norm=norm),
129129
fft.fft(xp.asarray(x, dtype=xp.complex128),
@@ -285,7 +285,7 @@ def test_all_1d_norm_preserving(self, xp):
285285
x = xp.asarray(random(30), dtype=xp.float64)
286286
xp_test = array_namespace(x)
287287
x_norm = xp_test.linalg.vector_norm(x)
288-
n = size(x) * 2
288+
n = xp_size(x) * 2
289289
func_pairs = [(fft.rfft, fft.irfft),
290290
# hfft: order so the first function takes x.size samples
291291
# (necessary for comparison to x_norm above)
@@ -297,7 +297,7 @@ def test_all_1d_norm_preserving(self, xp):
297297
if forw == fft.fft:
298298
x = xp.asarray(x, dtype=xp.complex128)
299299
x_norm = xp_test.linalg.vector_norm(x)
300-
for n in [size(x), 2*size(x)]:
300+
for n in [xp_size(x), 2*xp_size(x)]:
301301
for norm in ['backward', 'ortho', 'forward']:
302302
tmp = forw(x, n=n, norm=norm)
303303
tmp = back(tmp, n=n, norm=norm)

0 commit comments

Comments
 (0)