Skip to content

Commit 6ba4872

Browse files
committed
BUG: ensure np.median does not drop subclass for NaN result.
Currently, np.median is almost completely safe for subclasses, except if the result is NaN. In that case, it assumes the result is a scalar and substitutes a NaN with the right dtype. This PR fixes that, since subclasses like astropy's Quantity generally use array scalars to preserve subclass information such as the unit.
1 parent e4f85b0 commit 6ba4872

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

numpy/lib/tests/test_function_base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3432,6 +3432,16 @@ def mean(self, axis=None, dtype=None, out=None):
34323432
a = MySubClass([1, 2, 3])
34333433
assert_equal(np.median(a), -7)
34343434

3435+
@pytest.mark.parametrize('arr',
3436+
([1., 2., 3.], [1., np.nan, 3.], np.nan, 0.))
3437+
def test_subclass2(self, arr):
3438+
"""Check that we return subclasses, even if a NaN scalar."""
3439+
class MySubclass(np.ndarray):
3440+
pass
3441+
3442+
m = np.median(np.array(arr).view(MySubclass))
3443+
assert isinstance(m, MySubclass)
3444+
34353445
def test_out(self):
34363446
o = np.zeros((4,))
34373447
d = np.ones((3, 4))

numpy/lib/utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,14 +1029,12 @@ def _median_nancheck(data, result, axis, out):
10291029
# masked NaN values are ok
10301030
if np.ma.isMaskedArray(n):
10311031
n = n.filled(False)
1032-
if result.ndim == 0:
1033-
if n == True:
1034-
if out is not None:
1035-
out[...] = data.dtype.type(np.nan)
1036-
result = out
1037-
else:
1038-
result = data.dtype.type(np.nan)
1039-
elif np.count_nonzero(n.ravel()) > 0:
1032+
if np.count_nonzero(n.ravel()) > 0:
1033+
# Without given output, it is possible that the current result is a
1034+
# numpy scalar, which is not writeable. If so, just return nan.
1035+
if isinstance(result, np.generic):
1036+
return data.dtype.type(np.nan)
1037+
10401038
result[n] = np.nan
10411039
return result
10421040

0 commit comments

Comments
 (0)