Skip to content

Commit dc7dafe

Browse files
authored
Merge pull request numpy#19869 from mhvk/median_scalar_nan
BUG: ensure np.median does not drop subclass for NaN result.
2 parents cd9a4f5 + 9377d36 commit dc7dafe

File tree

4 files changed: +32 −25 lines changed

numpy/lib/function_base.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3714,16 +3714,15 @@ def _median(a, axis=None, out=None, overwrite_input=False):
3714 3714           indexer[axis] = slice(index-1, index+1)
3715 3715       indexer = tuple(indexer)
3716 3716
     3717 +     # Use mean in both odd and even case to coerce data type,
     3718 +     # using out array if needed.
     3719 +     rout = mean(part[indexer], axis=axis, out=out)
3717 3720       # Check if the array contains any nan's
3718 3721       if np.issubdtype(a.dtype, np.inexact) and sz > 0:
3719      -         # warn and return nans like mean would
3720      -         rout = mean(part[indexer], axis=axis, out=out)
3721      -         return np.lib.utils._median_nancheck(part, rout, axis, out)
3722      -     else:
3723      -         # if there are no nans
3724      -         # Use mean in odd and even case to coerce data type
3725      -         # and check, use out array.
3726      -         return mean(part[indexer], axis=axis, out=out)
     3722 +         # If nans are possible, warn and replace by nans like mean would.
     3723 +         rout = np.lib.utils._median_nancheck(part, rout, axis)
     3724 +
     3725 +     return rout
3727 3726
3728 3727
3729 3728   def _percentile_dispatcher(a, q, axis=None, out=None, overwrite_input=None,

numpy/lib/tests/test_function_base.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3432,6 +3432,16 @@ def mean(self, axis=None, dtype=None, out=None):
3432 3432           a = MySubClass([1, 2, 3])
3433 3433           assert_equal(np.median(a), -7)
3434 3434
     3435 +     @pytest.mark.parametrize('arr',
     3436 +                              ([1., 2., 3.], [1., np.nan, 3.], np.nan, 0.))
     3437 +     def test_subclass2(self, arr):
     3438 +         """Check that we return subclasses, even if a NaN scalar."""
     3439 +         class MySubclass(np.ndarray):
     3440 +             pass
     3441 +
     3442 +         m = np.median(np.array(arr).view(MySubclass))
     3443 +         assert isinstance(m, MySubclass)
     3444 +
3435 3445       def test_out(self):
3436 3446           o = np.zeros((4,))
3437 3447           d = np.ones((3, 4))

numpy/lib/utils.py

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1002,41 +1002,39 @@ def safe_eval(source):
1002 1002       return ast.literal_eval(source)
1003 1003
1004 1004
1005      - def _median_nancheck(data, result, axis, out):
     1005 + def _median_nancheck(data, result, axis):
1006 1006       """
1007 1007       Utility function to check median result from data for NaN values at the end
1008 1008       and return NaN in that case. Input result can also be a MaskedArray.
1009 1009
1010 1010       Parameters
1011 1011       ----------
1012 1012       data : array
1013      -         Input data to median function
     1013 +         Sorted input data to median function
1014 1014       result : Array or MaskedArray
1015      -         Result of median function
     1015 +         Result of median function.
1016 1016       axis : int
1017 1017           Axis along which the median was computed.
1018      -     out : ndarray, optional
1019      -         Output array in which to place the result.
1020 1018
1021 1019       Returns
1022 1020       -------
1023      -     median : scalar or ndarray
1024      -         Median or NaN in axes which contained NaN in the input.
     1021 +     result : scalar or ndarray
     1022 +         Median or NaN in axes which contained NaN in the input. If the input
     1023 +         was an array, NaN will be inserted in-place. If a scalar, either the
     1024 +         input itself or a scalar NaN.
1025 1025       """
1026 1026       if data.size == 0:
1027 1027           return result
1028 1028       n = np.isnan(data.take(-1, axis=axis))
1029 1029       # masked NaN values are ok
1030 1030       if np.ma.isMaskedArray(n):
1031 1031           n = n.filled(False)
1032      -     if result.ndim == 0:
1033      -         if n == True:
1034      -             if out is not None:
1035      -                 out[...] = data.dtype.type(np.nan)
1036      -                 result = out
1037      -             else:
1038      -                 result = data.dtype.type(np.nan)
1039      -     elif np.count_nonzero(n.ravel()) > 0:
     1032 +     if np.count_nonzero(n.ravel()) > 0:
     1033 +         # Without given output, it is possible that the current result is a
     1034 +         # numpy scalar, which is not writeable. If so, just return nan.
     1035 +         if isinstance(result, np.generic):
     1036 +             return data.dtype.type(np.nan)
     1037 +
1040 1038           result[n] = np.nan
1041 1039       return result
1042 1040

numpy/ma/extras.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -750,7 +750,7 @@ def _median(a, axis=None, out=None, overwrite_input=False):
750 750               s = mid.sum(out=out)
751 751               if not odd:
752 752                   s = np.true_divide(s, 2., casting='safe', out=out)
753     -             s = np.lib.utils._median_nancheck(asorted, s, axis, out)
    753 +             s = np.lib.utils._median_nancheck(asorted, s, axis)
754 754           else:
755 755               s = mid.mean(out=out)
756 756

@@ -790,7 +790,7 @@ def replace_masked(s):
790 790           s = np.ma.sum(low_high, axis=axis, out=out)
791 791           np.true_divide(s.data, 2., casting='unsafe', out=s.data)
792 792
793     -         s = np.lib.utils._median_nancheck(asorted, s, axis, out)
    793 +         s = np.lib.utils._median_nancheck(asorted, s, axis)
794 794       else:
795 795           s = np.ma.mean(low_high, axis=axis, out=out)
796 796

0 commit comments

Comments
 (0)