Skip to content

Commit 95bc8ec

Browse files
committed
MAINT: utilise array_api_compat v1.8
[skip cirrus] [skip circle]
1 parent 40d3998 commit 95bc8ec

File tree

7 files changed

+23
-62
lines changed

7 files changed

+23
-62
lines changed

scipy/_lib/_array_api.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,11 @@ def array_namespace(*arrays: Array) -> ModuleType:
126126

127127
_arrays = compliance_scipy(_arrays)
128128

129-
return array_api_compat.array_namespace(*_arrays)
129+
# data-apis/array-api-compat#168
130+
try: # return the wrapped namespace for NumPy arrays
131+
return array_api_compat.array_namespace(*_arrays, use_compat=True)
132+
except ValueError: # if the library is not wrapped, like array-api-strict
133+
return array_api_compat.array_namespace(*_arrays, use_compat=None)
130134

131135

132136
def _asarray(
@@ -493,41 +497,6 @@ def scipy_namespace_for(xp: ModuleType) -> ModuleType | None:
493497
return None
494498

495499

496-
# temporary substitute for xp.minimum, which is not yet in all backends
497-
# or covered by array_api_compat.
498-
def xp_minimum(x1: Array, x2: Array, /) -> Array:
499-
# xp won't be passed in because it doesn't need to be passed in to xp.minimum
500-
xp = array_namespace(x1, x2)
501-
if hasattr(xp, 'minimum'):
502-
return xp.minimum(x1, x2)
503-
x1, x2 = xp.broadcast_arrays(x1, x2)
504-
i = (x2 < x1) | xp.isnan(x2)
505-
res = xp.where(i, x2, x1)
506-
return res[()] if res.ndim == 0 else res
507-
508-
509-
# temporary substitute for xp.clip, which is not yet in all backends
510-
# or covered by array_api_compat.
511-
def xp_clip(
512-
x: Array,
513-
/,
514-
min: int | float | Array | None = None,
515-
max: int | float | Array | None = None,
516-
*,
517-
xp: ModuleType | None = None) -> Array:
518-
xp = array_namespace(x) if xp is None else xp
519-
a, b = xp.asarray(min, dtype=x.dtype), xp.asarray(max, dtype=x.dtype)
520-
if hasattr(xp, 'clip'):
521-
return xp.clip(x, a, b)
522-
x, a, b = xp.broadcast_arrays(x, a, b)
523-
y = xp.asarray(x, copy=True)
524-
ia = y < a
525-
y[ia] = a[ia]
526-
ib = y > b
527-
y[ib] = b[ib]
528-
return y[()] if y.ndim == 0 else y
529-
530-
531500
# temporary substitute for xp.moveaxis, which is not yet in all backends
532501
# or covered by array_api_compat.
533502
def xp_moveaxis_to_end(

scipy/optimize/_chandrupatla.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import numpy as np
33
import scipy._lib._elementwise_iterative_method as eim
44
from scipy._lib._util import _RichResult
5-
from scipy._lib._array_api import (xp_clip, xp_minimum, xp_sign, copy as xp_copy,
6-
xp_take_along_axis)
5+
from scipy._lib._array_api import xp_sign, copy as xp_copy, xp_take_along_axis
76

87
# TODO:
98
# - (maybe?) don't use fancy indexing assignment
@@ -142,7 +141,7 @@ def _chandrupatla(func, a, b, *, args=(), xatol=None, xrtol=None,
142141
xatol = 4*finfo.smallest_normal if xatol is None else xatol
143142
xrtol = 4*finfo.eps if xrtol is None else xrtol
144143
fatol = finfo.smallest_normal if fatol is None else fatol
145-
frtol = frtol * xp_minimum(xp.abs(f1), xp.abs(f2))
144+
frtol = frtol * xp.minimum(xp.abs(f1), xp.abs(f2))
146145
maxiter = (math.log2(finfo.max) - math.log2(finfo.smallest_normal)
147146
if maxiter is None else maxiter)
148147
work = _RichResult(x1=x1, f1=f1, x2=x2, f2=f2, x3=None, f3=None, t=0.5,
@@ -226,7 +225,7 @@ def post_termination_check(work):
226225
# [1] Figure 1 (last box; see also BASIC in appendix with comment
227226
# "Adjust T Away from the Interval Boundary")
228227
tl = 0.5 * work.tol / work.dx
229-
work.t = xp_clip(t, tl, 1 - tl)
228+
work.t = xp.clip(t, tl, 1 - tl)
230229

231230
def customize_result(res, shape):
232231
xl, xr, fl, fr = res['xl'], res['xr'], res['fl'], res['fr']

scipy/optimize/tests/test_chandrupatla.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import scipy._lib._elementwise_iterative_method as eim
77
from scipy.conftest import array_api_compatible
88
from scipy._lib._array_api import (array_namespace, xp_assert_close, xp_assert_equal,
9-
xp_assert_less, xp_minimum, is_numpy, is_cupy,
9+
xp_assert_less, is_numpy, is_cupy,
1010
xp_ravel, size as xp_size)
1111

1212
from scipy.optimize.elementwise import find_minimum, find_root
@@ -655,7 +655,7 @@ def f(*args, **kwargs):
655655
xp_assert_equal(res.fr, self.f(res.xr, *args_xp))
656656

657657
assert xp.all(xp.abs(res.fun[finite]) ==
658-
xp_minimum(xp.abs(res.fl[finite]),
658+
xp.minimum(xp.abs(res.fl[finite]),
659659
xp.abs(res.fr[finite])))
660660

661661
def test_flags(self, xp):
@@ -727,7 +727,7 @@ def test_convergence(self, xp):
727727
kwargs = kwargs0.copy()
728728
kwargs['frtol'] = 1e-3
729729
x1, x2 = bracket
730-
f0 = xp_minimum(xp.abs(self.f(x1, *args)), xp.abs(self.f(x2, *args)))
730+
f0 = xp.minimum(xp.abs(self.f(x1, *args)), xp.abs(self.f(x2, *args)))
731731
res1 = _chandrupatla_root(self.f, *bracket, **kwargs)
732732
xp_assert_less(xp.abs(res1.fun), 1e-3*f0)
733733
kwargs['frtol'] = 1e-6

scipy/stats/_morestats.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
from scipy._lib._array_api import (
1616
array_namespace,
1717
size as xp_size,
18-
xp_minimum,
1918
xp_moveaxis_to_end,
2019
xp_vector_norm,
21-
xp_clip
2220
)
2321

2422
from ._ansari_swilk_statistics import gscale, swilk
@@ -2899,7 +2897,7 @@ def bartlett(*samples, axis=0):
28992897
chi2 = _SimpleChi2(xp.asarray(k-1))
29002898
pvalue = _get_pvalue(T, chi2, alternative='greater', symmetric=False, xp=xp)
29012899

2902-
T = xp_clip(T, min=0., max=xp.inf)
2900+
T = xp.clip(T, min=0., max=xp.inf)
29032901
T = T[()] if T.ndim == 0 else T
29042902
pvalue = pvalue[()] if pvalue.ndim == 0 else pvalue
29052903

@@ -4103,7 +4101,7 @@ def circvar(samples, high=2*pi, low=0, axis=None, nan_policy='propagate'):
41034101
cos_mean = xp.mean(cos_samp, axis=axis)
41044102
hypotenuse = (sin_mean**2. + cos_mean**2.)**0.5
41054103
# hypotenuse can go slightly above 1 due to rounding errors
4106-
R = xp_minimum(xp.asarray(1.), hypotenuse)
4104+
R = xp.clip(hypotenuse, max=1.)
41074105

41084106
res = 1. - R
41094107
return res
@@ -4207,7 +4205,7 @@ def circstd(samples, high=2*pi, low=0, axis=None, nan_policy='propagate', *,
42074205
cos_mean = xp.mean(cos_samp, axis=axis) # [1] (2.2.3)
42084206
hypotenuse = (sin_mean**2. + cos_mean**2.)**0.5
42094207
# hypotenuse can go slightly above 1 due to rounding errors
4210-
R = xp_minimum(xp.asarray(1.), hypotenuse) # [1] (2.2.4)
4208+
R = xp.clip(hypotenuse, max=1.) # [1] (2.2.4)
42114209

42124210
res = (-2*xp.log(R))**0.5+0.0 # torch.pow returns -0.0 if R==1
42134211
if not normalize:

scipy/stats/_resampling.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import inspect
99

1010
from scipy._lib._util import check_random_state, _rename_parameter, rng_integers
11-
from scipy._lib._array_api import (array_namespace, is_numpy, xp_minimum,
12-
xp_clip, xp_moveaxis_to_end)
11+
from scipy._lib._array_api import array_namespace, is_numpy, xp_moveaxis_to_end
1312
from scipy.special import ndtr, ndtri, comb, factorial
1413

1514
from ._common import ConfidenceInterval
@@ -996,15 +995,15 @@ def greater(null_distribution, observed):
996995
def two_sided(null_distribution, observed):
997996
pvalues_less = less(null_distribution, observed)
998997
pvalues_greater = greater(null_distribution, observed)
999-
pvalues = xp_minimum(pvalues_less, pvalues_greater) * 2
998+
pvalues = xp.minimum(pvalues_less, pvalues_greater) * 2
1000999
return pvalues
10011000

10021001
compare = {"less": less,
10031002
"greater": greater,
10041003
"two-sided": two_sided}
10051004

10061005
pvalues = compare[alternative](null_distribution, observed)
1007-
pvalues = xp_clip(pvalues, 0., 1., xp=xp)
1006+
pvalues = xp.clip(pvalues, 0., 1.)
10081007

10091008
return MonteCarloTestResult(observed, pvalues, null_distribution)
10101009

scipy/stats/_stats_py.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@
7171
from scipy.optimize import root_scalar
7272
from scipy._lib._util import normalize_axis_index
7373
from scipy._lib._array_api import (array_namespace, is_numpy, atleast_nd,
74-
xp_clip, xp_moveaxis_to_end, xp_sign,
75-
xp_minimum, xp_vector_norm)
74+
xp_moveaxis_to_end, xp_sign, xp_vector_norm)
7675
from scipy._lib.array_api_compat import size as xp_size
7776
from scipy._lib.deprecation import _deprecated
7877

@@ -1518,9 +1517,8 @@ def _get_pvalue(statistic, distribution, alternative, symmetric=True, xp=None):
15181517
pvalue = distribution.sf(statistic)
15191518
elif alternative == 'two-sided':
15201519
pvalue = 2 * (distribution.sf(xp.abs(statistic)) if symmetric
1521-
else xp_minimum(distribution.cdf(statistic),
1522-
distribution.sf(statistic),
1523-
xp=xp))
1520+
else xp.minimum(distribution.cdf(statistic),
1521+
distribution.sf(statistic)))
15241522
else:
15251523
message = "`alternative` must be 'less', 'greater', or 'two-sided'."
15261524
raise ValueError(message)
@@ -4625,9 +4623,7 @@ def statistic(x, y, axis):
46254623
# Presumably, if abs(r) > 1, then it is only some small artifact of
46264624
# floating point arithmetic.
46274625
one = xp.asarray(1, dtype=dtype)
4628-
# `clip` only recently added to array API, so it's not yet available in
4629-
# array_api_strict. Replace with e.g. `xp.clip(r, -one, one)` when available.
4630-
r = xp.asarray(xp_clip(r, -one, one, xp=xp))
4626+
r = xp.asarray(xp.clip(r, -one, one))
46314627
r[const_xy] = xp.nan
46324628

46334629
# As explained in the docstring, the distribution of `r` under the null
@@ -7073,7 +7069,7 @@ def warn_masked(arg):
70737069
f_obs_sum = _m_sum(f_obs_float, axis=axis, preserve_mask=False, xp=xp)
70747070
f_exp_sum = _m_sum(f_exp, axis=axis, preserve_mask=False, xp=xp)
70757071
relative_diff = (xp.abs(f_obs_sum - f_exp_sum) /
7076-
xp_minimum(f_obs_sum, f_exp_sum))
7072+
xp.minimum(f_obs_sum, f_exp_sum))
70777073
diff_gt_tol = xp.any(relative_diff > rtol, axis=None)
70787074
if diff_gt_tol:
70797075
msg = (f"For each axis slice, the sum of the observed "

0 commit comments

Comments
 (0)