Skip to content

Commit 8c3e1f8

Browse files
committed
BUG: fix _asarray for masked arrays
[skip cirrus] [skip circle]
1 parent 95bc8ec commit 8c3e1f8

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

scipy/_lib/_array_api.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,14 @@ def _asarray(
159159
"""
160160
if xp is None:
161161
xp = array_namespace(array)
162-
if xp.__name__ in {"numpy", "scipy._lib.array_api_compat.numpy"}:
162+
if is_numpy(xp):
163163
# Use NumPy API to support order
164164
if copy is True:
165165
array = np.array(array, order=order, dtype=dtype, subok=subok)
166166
elif subok:
167167
array = np.asanyarray(array, order=order, dtype=dtype)
168168
else:
169169
array = np.asarray(array, order=order, dtype=dtype)
170-
171-
# At this point array is a NumPy ndarray. We convert it to an array
172-
# container that is consistent with the input's namespace.
173-
array = xp.asarray(array)
174170
else:
175171
try:
176172
array = xp.asarray(array, dtype=dtype, copy=copy)

scipy/stats/_stats_py.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,16 @@
7070
from scipy import stats
7171
from scipy.optimize import root_scalar
7272
from scipy._lib._util import normalize_axis_index
73-
from scipy._lib._array_api import (array_namespace, is_numpy, atleast_nd,
74-
xp_moveaxis_to_end, xp_sign, xp_vector_norm)
75-
from scipy._lib.array_api_compat import size as xp_size
73+
from scipy._lib._array_api import (
74+
_asarray,
75+
array_namespace,
76+
atleast_nd,
77+
is_numpy,
78+
size as xp_size,
79+
xp_moveaxis_to_end,
80+
xp_sign,
81+
xp_vector_norm,
82+
)
7683
from scipy._lib.deprecation import _deprecated
7784

7885

@@ -10440,7 +10447,7 @@ def _xp_mean(x, /, *, axis=None, weights=None, keepdims=False, nan_policy='propa
1044010447
"""
1044110448
# ensure that `x` and `weights` are array-API compatible arrays of identical shape
1044210449
xp = array_namespace(x) if xp is None else xp
10443-
x = xp.asarray(x, dtype=dtype)
10450+
x = _asarray(x, dtype=dtype, subok=True)
1044410451
weights = xp.asarray(weights, dtype=dtype) if weights is not None else weights
1044510452

1044610453
# to ensure that this matches the behavior of decorated functions when one of the
@@ -10526,7 +10533,7 @@ def _xp_var(x, /, *, axis=None, correction=0, keepdims=False, nan_policy='propag
1052610533
# an array-api compatible function for variance with scipy.stats interface
1052710534
# and features (e.g. `nan_policy`).
1052810535
xp = array_namespace(x) if xp is None else xp
10529-
x = xp.asarray(x)
10536+
x = _asarray(x, subok=True)
1053010537

1053110538
# use `_xp_mean` instead of `xp.var` for desired warning behavior
1053210539
# it would be nice to combine this with `_var`, which uses `_moment`
@@ -10536,7 +10543,7 @@ def _xp_var(x, /, *, axis=None, correction=0, keepdims=False, nan_policy='propag
1053610543
# be easy.
1053710544
kwargs = dict(axis=axis, nan_policy=nan_policy, dtype=dtype, xp=xp)
1053810545
mean = _xp_mean(x, keepdims=True, **kwargs)
10539-
x = xp.asarray(x, dtype=mean.dtype)
10546+
x = _asarray(x, dtype=mean.dtype, subok=True)
1054010547
x_mean = _demean(x, mean, axis, xp=xp)
1054110548
var = _xp_mean(x_mean**2, keepdims=keepdims, **kwargs)
1054210549

0 commit comments

Comments
 (0)