Skip to content

Commit 64a49b0

Browse files
authored
Merge pull request scipy#21264 from lucascolley/xp-compat-1.8
MAINT: utilise `array_api_compat` v1.8
2 parents e003d8c + 468a09c commit 64a49b0

File tree

9 files changed

+36
-74
lines changed

9 files changed

+36
-74
lines changed

scipy/_lib/_array_api.py

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -155,18 +155,14 @@ def _asarray(
155155
"""
156156
if xp is None:
157157
xp = array_namespace(array)
158-
if xp.__name__ in {"numpy", "scipy._lib.array_api_compat.numpy"}:
158+
if is_numpy(xp):
159159
# Use NumPy API to support order
160160
if copy is True:
161161
array = np.array(array, order=order, dtype=dtype, subok=subok)
162162
elif subok:
163163
array = np.asanyarray(array, order=order, dtype=dtype)
164164
else:
165165
array = np.asarray(array, order=order, dtype=dtype)
166-
167-
# At this point array is a NumPy ndarray. We convert it to an array
168-
# container that is consistent with the input's namespace.
169-
array = xp.asarray(array)
170166
else:
171167
try:
172168
array = xp.asarray(array, dtype=dtype, copy=copy)
@@ -493,41 +489,6 @@ def scipy_namespace_for(xp: ModuleType) -> ModuleType | None:
493489
return None
494490

495491

496-
# temporary substitute for xp.minimum, which is not yet in all backends
497-
# or covered by array_api_compat.
498-
def xp_minimum(x1: Array, x2: Array, /) -> Array:
499-
# xp won't be passed in because it doesn't need to be passed in to xp.minimum
500-
xp = array_namespace(x1, x2)
501-
if hasattr(xp, 'minimum'):
502-
return xp.minimum(x1, x2)
503-
x1, x2 = xp.broadcast_arrays(x1, x2)
504-
i = (x2 < x1) | xp.isnan(x2)
505-
res = xp.where(i, x2, x1)
506-
return res[()] if res.ndim == 0 else res
507-
508-
509-
# temporary substitute for xp.clip, which is not yet in all backends
510-
# or covered by array_api_compat.
511-
def xp_clip(
512-
x: Array,
513-
/,
514-
min: int | float | Array | None = None,
515-
max: int | float | Array | None = None,
516-
*,
517-
xp: ModuleType | None = None) -> Array:
518-
xp = array_namespace(x) if xp is None else xp
519-
a, b = xp.asarray(min, dtype=x.dtype), xp.asarray(max, dtype=x.dtype)
520-
if hasattr(xp, 'clip'):
521-
return xp.clip(x, a, b)
522-
x, a, b = xp.broadcast_arrays(x, a, b)
523-
y = xp.asarray(x, copy=True)
524-
ia = y < a
525-
y[ia] = a[ia]
526-
ib = y > b
527-
y[ib] = b[ib]
528-
return y[()] if y.ndim == 0 else y
529-
530-
531492
# temporary substitute for xp.moveaxis, which is not yet in all backends
532493
# or covered by array_api_compat.
533494
def xp_moveaxis_to_end(

scipy/optimize/_chandrupatla.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import numpy as np
33
import scipy._lib._elementwise_iterative_method as eim
44
from scipy._lib._util import _RichResult
5-
from scipy._lib._array_api import (xp_clip, xp_minimum, xp_sign, copy as xp_copy,
6-
xp_take_along_axis)
5+
from scipy._lib._array_api import xp_sign, copy as xp_copy, xp_take_along_axis
76

87
# TODO:
98
# - (maybe?) don't use fancy indexing assignment
@@ -142,7 +141,7 @@ def _chandrupatla(func, a, b, *, args=(), xatol=None, xrtol=None,
142141
xatol = 4*finfo.smallest_normal if xatol is None else xatol
143142
xrtol = 4*finfo.eps if xrtol is None else xrtol
144143
fatol = finfo.smallest_normal if fatol is None else fatol
145-
frtol = frtol * xp_minimum(xp.abs(f1), xp.abs(f2))
144+
frtol = frtol * xp.minimum(xp.abs(f1), xp.abs(f2))
146145
maxiter = (math.log2(finfo.max) - math.log2(finfo.smallest_normal)
147146
if maxiter is None else maxiter)
148147
work = _RichResult(x1=x1, f1=f1, x2=x2, f2=f2, x3=None, f3=None, t=0.5,
@@ -226,7 +225,7 @@ def post_termination_check(work):
226225
# [1] Figure 1 (last box; see also BASIC in appendix with comment
227226
# "Adjust T Away from the Interval Boundary")
228227
tl = 0.5 * work.tol / work.dx
229-
work.t = xp_clip(t, tl, 1 - tl)
228+
work.t = xp.clip(t, tl, 1 - tl)
230229

231230
def customize_result(res, shape):
232231
xl, xr, fl, fr = res['xl'], res['xr'], res['fl'], res['fr']

scipy/optimize/tests/test_chandrupatla.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import scipy._lib._elementwise_iterative_method as eim
77
from scipy.conftest import array_api_compatible
88
from scipy._lib._array_api import (array_namespace, xp_assert_close, xp_assert_equal,
9-
xp_assert_less, xp_minimum, is_numpy, is_cupy,
9+
xp_assert_less, is_numpy, is_cupy,
1010
xp_ravel, size as xp_size)
1111

1212
from scipy.optimize.elementwise import find_minimum, find_root
@@ -655,7 +655,7 @@ def f(*args, **kwargs):
655655
xp_assert_equal(res.fr, self.f(res.xr, *args_xp))
656656

657657
assert xp.all(xp.abs(res.fun[finite]) ==
658-
xp_minimum(xp.abs(res.fl[finite]),
658+
xp.minimum(xp.abs(res.fl[finite]),
659659
xp.abs(res.fr[finite])))
660660

661661
def test_flags(self, xp):
@@ -727,7 +727,7 @@ def test_convergence(self, xp):
727727
kwargs = kwargs0.copy()
728728
kwargs['frtol'] = 1e-3
729729
x1, x2 = bracket
730-
f0 = xp_minimum(xp.abs(self.f(x1, *args)), xp.abs(self.f(x2, *args)))
730+
f0 = xp.minimum(xp.abs(self.f(x1, *args)), xp.abs(self.f(x2, *args)))
731731
res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
732732
xp_assert_less(xp.abs(res1.fun), 1e-3*f0)
733733
kwargs['frtol'] = 1e-6

scipy/stats/_morestats.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
from scipy._lib._array_api import (
1616
array_namespace,
1717
size as xp_size,
18-
xp_minimum,
1918
xp_moveaxis_to_end,
2019
xp_vector_norm,
21-
xp_clip
2220
)
2321

2422
from ._ansari_swilk_statistics import gscale, swilk
@@ -2899,7 +2897,7 @@ def bartlett(*samples, axis=0):
28992897
chi2 = _SimpleChi2(xp.asarray(k-1))
29002898
pvalue = _get_pvalue(T, chi2, alternative='greater', symmetric=False, xp=xp)
29012899

2902-
T = xp_clip(T, min=0., max=xp.inf)
2900+
T = xp.clip(T, min=0., max=xp.inf)
29032901
T = T[()] if T.ndim == 0 else T
29042902
pvalue = pvalue[()] if pvalue.ndim == 0 else pvalue
29052903

@@ -4103,7 +4101,7 @@ def circvar(samples, high=2*pi, low=0, axis=None, nan_policy='propagate'):
41034101
cos_mean = xp.mean(cos_samp, axis=axis)
41044102
hypotenuse = (sin_mean**2. + cos_mean**2.)**0.5
41054103
# hypotenuse can go slightly above 1 due to rounding errors
4106-
R = xp_minimum(xp.asarray(1.), hypotenuse)
4104+
R = xp.clip(hypotenuse, max=1.)
41074105

41084106
res = 1. - R
41094107
return res
@@ -4207,7 +4205,7 @@ def circstd(samples, high=2*pi, low=0, axis=None, nan_policy='propagate', *,
42074205
cos_mean = xp.mean(cos_samp, axis=axis) # [1] (2.2.3)
42084206
hypotenuse = (sin_mean**2. + cos_mean**2.)**0.5
42094207
# hypotenuse can go slightly above 1 due to rounding errors
4210-
R = xp_minimum(xp.asarray(1.), hypotenuse) # [1] (2.2.4)
4208+
R = xp.clip(hypotenuse, max=1.) # [1] (2.2.4)
42114209

42124210
res = (-2*xp.log(R))**0.5+0.0 # torch.pow returns -0.0 if R==1
42134211
if not normalize:

scipy/stats/_resampling.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import inspect
99

1010
from scipy._lib._util import check_random_state, _rename_parameter, rng_integers
11-
from scipy._lib._array_api import (array_namespace, is_numpy, xp_minimum,
12-
xp_clip, xp_moveaxis_to_end)
11+
from scipy._lib._array_api import array_namespace, is_numpy, xp_moveaxis_to_end
1312
from scipy.special import ndtr, ndtri, comb, factorial
1413

1514
from ._common import ConfidenceInterval
@@ -996,15 +995,15 @@ def greater(null_distribution, observed):
996995
def two_sided(null_distribution, observed):
997996
pvalues_less = less(null_distribution, observed)
998997
pvalues_greater = greater(null_distribution, observed)
999-
pvalues = xp_minimum(pvalues_less, pvalues_greater) * 2
998+
pvalues = xp.minimum(pvalues_less, pvalues_greater) * 2
1000999
return pvalues
10011000

10021001
compare = {"less": less,
10031002
"greater": greater,
10041003
"two-sided": two_sided}
10051004

10061005
pvalues = compare[alternative](null_distribution, observed)
1007-
pvalues = xp_clip(pvalues, 0., 1., xp=xp)
1006+
pvalues = xp.clip(pvalues, 0., 1.)
10081007

10091008
return MonteCarloTestResult(observed, pvalues, null_distribution)
10101009

scipy/stats/_stats_py.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,16 @@
7070
from scipy import stats
7171
from scipy.optimize import root_scalar
7272
from scipy._lib._util import normalize_axis_index
73-
from scipy._lib._array_api import (array_namespace, is_numpy, atleast_nd,
74-
xp_clip, xp_moveaxis_to_end, xp_sign,
75-
xp_minimum, xp_vector_norm)
76-
from scipy._lib.array_api_compat import size as xp_size
73+
from scipy._lib._array_api import (
74+
_asarray,
75+
array_namespace,
76+
atleast_nd,
77+
is_numpy,
78+
size as xp_size,
79+
xp_moveaxis_to_end,
80+
xp_sign,
81+
xp_vector_norm,
82+
)
7783
from scipy._lib.deprecation import _deprecated
7884

7985

@@ -1518,9 +1524,8 @@ def _get_pvalue(statistic, distribution, alternative, symmetric=True, xp=None):
15181524
pvalue = distribution.sf(statistic)
15191525
elif alternative == 'two-sided':
15201526
pvalue = 2 * (distribution.sf(xp.abs(statistic)) if symmetric
1521-
else xp_minimum(distribution.cdf(statistic),
1522-
distribution.sf(statistic),
1523-
xp=xp))
1527+
else xp.minimum(distribution.cdf(statistic),
1528+
distribution.sf(statistic)))
15241529
else:
15251530
message = "`alternative` must be 'less', 'greater', or 'two-sided'."
15261531
raise ValueError(message)
@@ -4625,9 +4630,7 @@ def statistic(x, y, axis):
46254630
# Presumably, if abs(r) > 1, then it is only some small artifact of
46264631
# floating point arithmetic.
46274632
one = xp.asarray(1, dtype=dtype)
4628-
# `clip` only recently added to array API, so it's not yet available in
4629-
# array_api_strict. Replace with e.g. `xp.clip(r, -one, one)` when available.
4630-
r = xp.asarray(xp_clip(r, -one, one, xp=xp))
4633+
r = xp.asarray(xp.clip(r, -one, one))
46314634
r[const_xy] = xp.nan
46324635

46334636
# As explained in the docstring, the distribution of `r` under the null
@@ -7073,7 +7076,7 @@ def warn_masked(arg):
70737076
f_obs_sum = _m_sum(f_obs_float, axis=axis, preserve_mask=False, xp=xp)
70747077
f_exp_sum = _m_sum(f_exp, axis=axis, preserve_mask=False, xp=xp)
70757078
relative_diff = (xp.abs(f_obs_sum - f_exp_sum) /
7076-
xp_minimum(f_obs_sum, f_exp_sum))
7079+
xp.minimum(f_obs_sum, f_exp_sum))
70777080
diff_gt_tol = xp.any(relative_diff > rtol, axis=None)
70787081
if diff_gt_tol:
70797082
msg = (f"For each axis slice, the sum of the observed "
@@ -10444,7 +10447,7 @@ def _xp_mean(x, /, *, axis=None, weights=None, keepdims=False, nan_policy='propa
1044410447
"""
1044510448
# ensure that `x` and `weights` are array-API compatible arrays of identical shape
1044610449
xp = array_namespace(x) if xp is None else xp
10447-
x = xp.asarray(x, dtype=dtype)
10450+
x = _asarray(x, dtype=dtype, subok=True)
1044810451
weights = xp.asarray(weights, dtype=dtype) if weights is not None else weights
1044910452

1045010453
# to ensure that this matches the behavior of decorated functions when one of the
@@ -10530,7 +10533,7 @@ def _xp_var(x, /, *, axis=None, correction=0, keepdims=False, nan_policy='propag
1053010533
# an array-api compatible function for variance with scipy.stats interface
1053110534
# and features (e.g. `nan_policy`).
1053210535
xp = array_namespace(x) if xp is None else xp
10533-
x = xp.asarray(x)
10536+
x = _asarray(x, subok=True)
1053410537

1053510538
# use `_xp_mean` instead of `xp.var` for desired warning behavior
1053610539
# it would be nice to combine this with `_var`, which uses `_moment`
@@ -10540,7 +10543,7 @@ def _xp_var(x, /, *, axis=None, correction=0, keepdims=False, nan_policy='propag
1054010543
# be easy.
1054110544
kwargs = dict(axis=axis, nan_policy=nan_policy, dtype=dtype, xp=xp)
1054210545
mean = _xp_mean(x, keepdims=True, **kwargs)
10543-
x = xp.asarray(x, dtype=mean.dtype)
10546+
x = _asarray(x, dtype=mean.dtype, subok=True)
1054410547
x_mean = _demean(x, mean, axis, xp=xp)
1054510548
var = _xp_mean(x_mean**2, keepdims=keepdims, **kwargs)
1054610549

tools/check_installation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
"_lib/array_api_compat/tests/test_array_namespace.py",
4545
"_lib/array_api_compat/tests/test_common.py",
4646
"_lib/array_api_compat/tests/test_isdtype.py",
47+
"_lib/array_api_compat/tests/test_no_dependencies.py",
4748
"_lib/array_api_compat/tests/test_vendoring.py",
48-
"_lib/array_api_compat/tests/test_array_namespace.py",
4949
"cobyqa/cobyqa/tests/test_main.py",
5050
"cobyqa/cobyqa/tests/test_models.py",
5151
"cobyqa/cobyqa/tests/test_problem.py",
@@ -103,7 +103,7 @@ def main(install_dir, no_tests):
103103
if pyi_file not in installed_pyi_files.keys():
104104
if no_tests and "test" in scipy_pyi_files[pyi_file]:
105105
continue
106-
raise Exception("%s is not installed" % scipy_pyi_files[pyi_file])
106+
raise Exception(f"{scipy_pyi_files[pyi_file]} is not installed")
107107

108108
print("----------- All the necessary .pyi files were installed --------------")
109109

tools/check_test_name.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def is_misnamed_test_func(
7171
and node.name
7272
not in ("teardown_method", "setup_method",
7373
"teardown_class", "setup_class",
74-
"setup_module", "teardown_module")
74+
"setup_module", "teardown_module",
75+
"_test_dependency", # array_api_compat.tests.test_no_dependencies
76+
)
7577
)
7678

7779

0 commit comments

Comments
 (0)