70
70
from scipy import stats
71
71
from scipy .optimize import root_scalar
72
72
from scipy ._lib ._util import normalize_axis_index
73
- from scipy ._lib ._array_api import (array_namespace , is_numpy , atleast_nd ,
74
- xp_moveaxis_to_end , xp_sign , xp_vector_norm )
75
- from scipy ._lib .array_api_compat import size as xp_size
73
+ from scipy ._lib ._array_api import (
74
+ _asarray ,
75
+ array_namespace ,
76
+ atleast_nd ,
77
+ is_numpy ,
78
+ size as xp_size ,
79
+ xp_moveaxis_to_end ,
80
+ xp_sign ,
81
+ xp_vector_norm ,
82
+ )
76
83
from scipy ._lib .deprecation import _deprecated
77
84
78
85
@@ -10440,7 +10447,7 @@ def _xp_mean(x, /, *, axis=None, weights=None, keepdims=False, nan_policy='propa
10440
10447
"""
10441
10448
# ensure that `x` and `weights` are array-API compatible arrays of identical shape
10442
10449
xp = array_namespace (x ) if xp is None else xp
10443
- x = xp . asarray (x , dtype = dtype )
10450
+ x = _asarray (x , dtype = dtype , subok = True )
10444
10451
weights = xp .asarray (weights , dtype = dtype ) if weights is not None else weights
10445
10452
10446
10453
# to ensure that this matches the behavior of decorated functions when one of the
@@ -10526,7 +10533,7 @@ def _xp_var(x, /, *, axis=None, correction=0, keepdims=False, nan_policy='propag
10526
10533
# an array-api compatible function for variance with scipy.stats interface
10527
10534
# and features (e.g. `nan_policy`).
10528
10535
xp = array_namespace (x ) if xp is None else xp
10529
- x = xp . asarray ( x )
10536
+ x = _asarray ( x , subok = True )
10530
10537
10531
10538
# use `_xp_mean` instead of `xp.var` for desired warning behavior
10532
10539
# it would be nice to combine this with `_var`, which uses `_moment`
@@ -10536,7 +10543,7 @@ def _xp_var(x, /, *, axis=None, correction=0, keepdims=False, nan_policy='propag
10536
10543
# be easy.
10537
10544
kwargs = dict (axis = axis , nan_policy = nan_policy , dtype = dtype , xp = xp )
10538
10545
mean = _xp_mean (x , keepdims = True , ** kwargs )
10539
- x = xp . asarray (x , dtype = mean .dtype )
10546
+ x = _asarray (x , dtype = mean .dtype , subok = True )
10540
10547
x_mean = _demean (x , mean , axis , xp = xp )
10541
10548
var = _xp_mean (x_mean ** 2 , keepdims = keepdims , ** kwargs )
10542
10549
0 commit comments