Skip to content

Commit b3a66e8

Browse files
authored
Merge pull request numpy#19821 from BvB93/nanfunctions
BUG: Fixed an issue wherein certain `nan<x>` functions could fail for object arrays
2 parents 4af05ea + ecba713 commit b3a66e8

File tree

2 files changed: 89 additions & 75 deletions

numpy/lib/nanfunctions.py

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -160,8 +160,12 @@ def _remove_nan_1d(arr1d, overwrite_input=False):
160160
True if `res` can be modified in place, given the constraint on the
161161
input
162162
"""
163+
if arr1d.dtype == object:
164+
# object arrays do not support `isnan` (gh-9009), so make a guess
165+
c = np.not_equal(arr1d, arr1d, dtype=bool)
166+
else:
167+
c = np.isnan(arr1d)
163168

164-
c = np.isnan(arr1d)
165169
s = np.nonzero(c)[0]
166170
if s.size == arr1d.size:
167171
warnings.warn("All-NaN slice encountered", RuntimeWarning,
@@ -214,7 +218,11 @@ def _divide_by_count(a, b, out=None):
214218
return np.divide(a, b, out=out, casting='unsafe')
215219
else:
216220
if out is None:
217-
return a.dtype.type(a / b)
221+
# Precaution against reduced object arrays
222+
try:
223+
return a.dtype.type(a / b)
224+
except AttributeError:
225+
return a / b
218226
else:
219227
# This is questionable, but currently a numpy scalar can
220228
# be output to a zero dimensional array.
@@ -1551,7 +1559,13 @@ def nanvar(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue):
15511559

15521560
# Compute variance.
15531561
var = np.sum(sqr, axis=axis, dtype=dtype, out=out, keepdims=keepdims)
1554-
if var.ndim < cnt.ndim:
1562+
1563+
# Precaution against reduced object arrays
1564+
try:
1565+
var_ndim = var.ndim
1566+
except AttributeError:
1567+
var_ndim = np.ndim(var)
1568+
if var_ndim < cnt.ndim:
15551569
# Subclasses of ndarray may ignore keepdims, so check here.
15561570
cnt = cnt.squeeze(axis)
15571571
dof = cnt - ddof
@@ -1671,6 +1685,8 @@ def nanstd(a, axis=None, dtype=None, out=None, ddof=0, keepdims=np._NoValue):
16711685
keepdims=keepdims)
16721686
if isinstance(var, np.ndarray):
16731687
std = np.sqrt(var, out=var)
1674-
else:
1688+
elif hasattr(var, 'dtype'):
16751689
std = var.dtype.type(np.sqrt(var))
1690+
else:
1691+
std = np.sqrt(var)
16761692
return std

numpy/lib/tests/test_nanfunctions.py

Lines changed: 69 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -231,79 +231,77 @@ class MyNDArray(np.ndarray):
231231
assert_(res.shape == ())
232232

233233

234-
class TestNanFunctions_IntTypes:
235-
236-
int_types = (np.int8, np.int16, np.int32, np.int64, np.uint8,
237-
np.uint16, np.uint32, np.uint64)
234+
@pytest.mark.parametrize(
235+
"dtype",
236+
np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "O",
237+
)
238+
class TestNanFunctions_NumberTypes:
238239

239240
mat = np.array([127, 39, 93, 87, 46])
240-
241-
def integer_arrays(self):
242-
for dtype in self.int_types:
243-
yield self.mat.astype(dtype)
244-
245-
def test_nanmin(self):
246-
tgt = np.min(self.mat)
247-
for mat in self.integer_arrays():
248-
assert_equal(np.nanmin(mat), tgt)
249-
250-
def test_nanmax(self):
251-
tgt = np.max(self.mat)
252-
for mat in self.integer_arrays():
253-
assert_equal(np.nanmax(mat), tgt)
254-
255-
def test_nanargmin(self):
256-
tgt = np.argmin(self.mat)
257-
for mat in self.integer_arrays():
258-
assert_equal(np.nanargmin(mat), tgt)
259-
260-
def test_nanargmax(self):
261-
tgt = np.argmax(self.mat)
262-
for mat in self.integer_arrays():
263-
assert_equal(np.nanargmax(mat), tgt)
264-
265-
def test_nansum(self):
266-
tgt = np.sum(self.mat)
267-
for mat in self.integer_arrays():
268-
assert_equal(np.nansum(mat), tgt)
269-
270-
def test_nanprod(self):
271-
tgt = np.prod(self.mat)
272-
for mat in self.integer_arrays():
273-
assert_equal(np.nanprod(mat), tgt)
274-
275-
def test_nancumsum(self):
276-
tgt = np.cumsum(self.mat)
277-
for mat in self.integer_arrays():
278-
assert_equal(np.nancumsum(mat), tgt)
279-
280-
def test_nancumprod(self):
281-
tgt = np.cumprod(self.mat)
282-
for mat in self.integer_arrays():
283-
assert_equal(np.nancumprod(mat), tgt)
284-
285-
def test_nanmean(self):
286-
tgt = np.mean(self.mat)
287-
for mat in self.integer_arrays():
288-
assert_equal(np.nanmean(mat), tgt)
289-
290-
def test_nanvar(self):
291-
tgt = np.var(self.mat)
292-
for mat in self.integer_arrays():
293-
assert_equal(np.nanvar(mat), tgt)
294-
295-
tgt = np.var(mat, ddof=1)
296-
for mat in self.integer_arrays():
297-
assert_equal(np.nanvar(mat, ddof=1), tgt)
298-
299-
def test_nanstd(self):
300-
tgt = np.std(self.mat)
301-
for mat in self.integer_arrays():
302-
assert_equal(np.nanstd(mat), tgt)
303-
304-
tgt = np.std(self.mat, ddof=1)
305-
for mat in self.integer_arrays():
306-
assert_equal(np.nanstd(mat, ddof=1), tgt)
241+
mat.setflags(write=False)
242+
243+
nanfuncs = {
244+
np.nanmin: np.min,
245+
np.nanmax: np.max,
246+
np.nanargmin: np.argmin,
247+
np.nanargmax: np.argmax,
248+
np.nansum: np.sum,
249+
np.nanprod: np.prod,
250+
np.nancumsum: np.cumsum,
251+
np.nancumprod: np.cumprod,
252+
np.nanmean: np.mean,
253+
np.nanmedian: np.median,
254+
np.nanvar: np.var,
255+
np.nanstd: np.std,
256+
}
257+
nanfunc_ids = [i.__name__ for i in nanfuncs]
258+
259+
@pytest.mark.parametrize("nanfunc,func", nanfuncs.items(), ids=nanfunc_ids)
260+
def test_nanfunc(self, dtype, nanfunc, func):
261+
if nanfunc is np.nanprod and dtype == "e":
262+
pytest.xfail(reason="overflow encountered in reduce")
263+
264+
mat = self.mat.astype(dtype)
265+
tgt = func(mat)
266+
out = nanfunc(mat)
267+
268+
assert_almost_equal(out, tgt)
269+
if dtype == "O":
270+
assert type(out) is type(tgt)
271+
else:
272+
assert out.dtype == tgt.dtype
273+
274+
@pytest.mark.parametrize(
275+
"nanfunc,func",
276+
[(np.nanquantile, np.quantile), (np.nanpercentile, np.percentile)],
277+
ids=["nanquantile", "nanpercentile"],
278+
)
279+
def test_nanfunc_q(self, dtype, nanfunc, func):
280+
mat = self.mat.astype(dtype)
281+
tgt = func(mat, q=1)
282+
out = nanfunc(mat, q=1)
283+
284+
assert_almost_equal(out, tgt)
285+
if dtype == "O":
286+
assert type(out) is type(tgt)
287+
else:
288+
assert out.dtype == tgt.dtype
289+
290+
@pytest.mark.parametrize(
291+
"nanfunc,func",
292+
[(np.nanvar, np.var), (np.nanstd, np.std)],
293+
ids=["nanvar", "nanstd"],
294+
)
295+
def test_nanfunc_ddof(self, dtype, nanfunc, func):
296+
mat = self.mat.astype(dtype)
297+
tgt = func(mat, ddof=1)
298+
out = nanfunc(mat, ddof=1)
299+
300+
assert_almost_equal(out, tgt)
301+
if dtype == "O":
302+
assert type(out) is type(tgt)
303+
else:
304+
assert out.dtype == tgt.dtype
307305

308306

309307
class SharedNanFunctionsTestsMixin:

0 commit comments

Comments
 (0)