6060]
6161
6262
63- def _replace_nan_no_mask (a , val ):
63+ def _replace_nan (a , val ):
6464 """
6565 Replace NaNs in array `a` with `val`.
6666
@@ -90,14 +90,14 @@ def _replace_nan_no_mask(a, val):
9090 return a
9191
9292
93- def _replace_nan (a , val ):
93+ def _replace_nan_test_axis (a , val , axis ):
9494 """
95- Replace NaNs in array `a` with `val`.
95+ Replace NaNs in array `a` with `val` and test for all-NaN slices .
9696
97- If `a` is of inexact type, make a copy of `a`, replace NaNs with
98- the `val` value, and return the copy together with a boolean mask
99- marking the locations where NaNs were present . If `a` is not of
100- inexact type, do nothing and return `a` together with a mask of None .
97+ If `a` is of inexact type, test along an axis for an all-NaN slice.
98+ If none are found, make a copy of `a`, replace NaNs with the `val`
99+ value, and return the copy . If `a` is not of inexact type, do
100+ nothing and return `a`.
101101
102102 Parameters
103103 ----------
@@ -108,27 +108,21 @@ def _replace_nan(a, val):
108108
109109 Returns
110110 -------
111- out : { dpnp.ndarray}
111+ out : dpnp.ndarray
112112 If `a` is of inexact type, return a copy of `a` with the NaNs
113113 replaced by the fill value, otherwise return `a`.
114- mask: {bool, None}
115- If `a` is of inexact type, return a boolean mask marking locations of
116- NaNs, otherwise return ``None``.
117114
118115 """
119116
120117 dpnp .check_supported_arrays_type (a )
121118 if dpnp .issubdtype (a .dtype , dpnp .inexact ):
122119 mask = dpnp .isnan (a )
123- if not dpnp .any (mask ):
124- mask = None
125- else :
126- a = dpnp .array (a , copy = True )
127- dpnp .copyto (a , val , where = mask )
128- else :
129- mask = None
120+ mask = dpnp .all (mask , axis = axis )
121+ if dpnp .any (mask ):
122+ raise ValueError ("All-NaN slice encountered" )
123+ return dpnp .nan_to_num (a , nan = val , posinf = dpnp .inf , neginf = - dpnp .inf )
130124
131- return a , mask
125+ return a
132126
133127
134128def nanargmax (a , axis = None , out = None , * , keepdims = False ):
@@ -197,11 +191,7 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False):
197191
198192 """
199193
200- a , mask = _replace_nan (a , - dpnp .inf )
201- if mask is not None :
202- mask = dpnp .all (mask , axis = axis )
203- if dpnp .any (mask ):
204- raise ValueError ("All-NaN slice encountered" )
194+ a = _replace_nan_test_axis (a , - dpnp .inf , axis )
205195 return dpnp .argmax (a , axis = axis , out = out , keepdims = keepdims )
206196
207197
@@ -271,11 +261,7 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False):
271261
272262 """
273263
274- a , mask = _replace_nan (a , dpnp .inf )
275- if mask is not None :
276- mask = dpnp .all (mask , axis = axis )
277- if dpnp .any (mask ):
278- raise ValueError ("All-NaN slice encountered" )
264+ a = _replace_nan_test_axis (a , dpnp .inf , axis )
279265 return dpnp .argmin (a , axis = axis , out = out , keepdims = keepdims )
280266
281267
@@ -345,7 +331,7 @@ def nancumprod(a, axis=None, dtype=None, out=None):
345331
346332 """
347333
348- a = _replace_nan_no_mask (a , 1.0 )
334+ a = _replace_nan (a , 1.0 )
349335 return dpnp .cumprod (a , axis = axis , dtype = dtype , out = out )
350336
351337
@@ -415,7 +401,7 @@ def nancumsum(a, axis=None, dtype=None, out=None):
415401
416402 """
417403
418- a = _replace_nan_no_mask (a , 0.0 )
404+ a = _replace_nan (a , 0.0 )
419405 return dpnp .cumsum (a , axis = axis , dtype = dtype , out = out )
420406
421407
@@ -914,7 +900,7 @@ def nanprod(
914900
915901 """
916902
917- a = _replace_nan_no_mask (a , 1.0 )
903+ a = _replace_nan (a , 1.0 )
918904 return dpnp .prod (
919905 a ,
920906 axis = axis ,
@@ -1018,7 +1004,7 @@ def nansum(
10181004
10191005 """
10201006
1021- a = _replace_nan_no_mask (a , 0.0 )
1007+ a = _replace_nan (a , 0.0 )
10221008 return dpnp .sum (
10231009 a ,
10241010 axis = axis ,
0 commit comments