Skip to content

Commit 0494853

Browse files
committed
Reuse nan_to_num in nanargmax and nanargmin
`_replace_nan_no_mask` becomes `_replace_nan` and the responsibility of raising the error moves into `_replace_nan_test_axis`
1 parent 4552fe8 commit 0494853

File tree

1 file changed

+19
-33
lines changed

1 file changed

+19
-33
lines changed

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
]
6161

6262

63-
def _replace_nan_no_mask(a, val):
63+
def _replace_nan(a, val):
6464
"""
6565
Replace NaNs in array `a` with `val`.
6666
@@ -90,14 +90,14 @@ def _replace_nan_no_mask(a, val):
9090
return a
9191

9292

93-
def _replace_nan(a, val):
93+
def _replace_nan_test_axis(a, val, axis):
9494
"""
95-
Replace NaNs in array `a` with `val`.
95+
Replace NaNs in array `a` with `val` and test for all-NaN slices.
9696
97-
If `a` is of inexact type, make a copy of `a`, replace NaNs with
98-
the `val` value, and return the copy together with a boolean mask
99-
marking the locations where NaNs were present. If `a` is not of
100-
inexact type, do nothing and return `a` together with a mask of None.
97+
If `a` is of inexact type, test along an axis for an all-NaN slice.
98+
If none are found, make a copy of `a`, replace NaNs with the `val`
99+
value, and return the copy. If `a` is not of inexact type, do
100+
nothing and return `a`.
101101
102102
Parameters
103103
----------
@@ -108,27 +108,21 @@ def _replace_nan(a, val):
108108
109109
Returns
110110
-------
111-
out : {dpnp.ndarray}
111+
out : dpnp.ndarray
112112
If `a` is of inexact type, return a copy of `a` with the NaNs
113113
replaced by the fill value, otherwise return `a`.
114-
mask: {bool, None}
115-
If `a` is of inexact type, return a boolean mask marking locations of
116-
NaNs, otherwise return ``None``.
117114
118115
"""
119116

120117
dpnp.check_supported_arrays_type(a)
121118
if dpnp.issubdtype(a.dtype, dpnp.inexact):
122119
mask = dpnp.isnan(a)
123-
if not dpnp.any(mask):
124-
mask = None
125-
else:
126-
a = dpnp.array(a, copy=True)
127-
dpnp.copyto(a, val, where=mask)
128-
else:
129-
mask = None
120+
mask = dpnp.all(mask, axis=axis)
121+
if dpnp.any(mask):
122+
raise ValueError("All-NaN slice encountered")
123+
return dpnp.nan_to_num(a, nan=val, posinf=dpnp.inf, neginf=-dpnp.inf)
130124

131-
return a, mask
125+
return a
132126

133127

134128
def nanargmax(a, axis=None, out=None, *, keepdims=False):
@@ -197,11 +191,7 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False):
197191
198192
"""
199193

200-
a, mask = _replace_nan(a, -dpnp.inf)
201-
if mask is not None:
202-
mask = dpnp.all(mask, axis=axis)
203-
if dpnp.any(mask):
204-
raise ValueError("All-NaN slice encountered")
194+
a = _replace_nan_test_axis(a, -dpnp.inf, axis)
205195
return dpnp.argmax(a, axis=axis, out=out, keepdims=keepdims)
206196

207197

@@ -271,11 +261,7 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):
271261
272262
"""
273263

274-
a, mask = _replace_nan(a, dpnp.inf)
275-
if mask is not None:
276-
mask = dpnp.all(mask, axis=axis)
277-
if dpnp.any(mask):
278-
raise ValueError("All-NaN slice encountered")
264+
a = _replace_nan_test_axis(a, dpnp.inf, axis)
279265
return dpnp.argmin(a, axis=axis, out=out, keepdims=keepdims)
280266

281267

@@ -345,7 +331,7 @@ def nancumprod(a, axis=None, dtype=None, out=None):
345331
346332
"""
347333

348-
a = _replace_nan_no_mask(a, 1.0)
334+
a = _replace_nan(a, 1.0)
349335
return dpnp.cumprod(a, axis=axis, dtype=dtype, out=out)
350336

351337

@@ -415,7 +401,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
415401
416402
"""
417403

418-
a = _replace_nan_no_mask(a, 0.0)
404+
a = _replace_nan(a, 0.0)
419405
return dpnp.cumsum(a, axis=axis, dtype=dtype, out=out)
420406

421407

@@ -914,7 +900,7 @@ def nanprod(
914900
915901
"""
916902

917-
a = _replace_nan_no_mask(a, 1.0)
903+
a = _replace_nan(a, 1.0)
918904
return dpnp.prod(
919905
a,
920906
axis=axis,
@@ -1018,7 +1004,7 @@ def nansum(
10181004
10191005
"""
10201006

1021-
a = _replace_nan_no_mask(a, 0.0)
1007+
a = _replace_nan(a, 0.0)
10221008
return dpnp.sum(
10231009
a,
10241010
axis=axis,

0 commit comments

Comments
 (0)