diff --git a/dpnp/dpnp_iface_nanfunctions.py b/dpnp/dpnp_iface_nanfunctions.py index c1f62dac85d5..a4029636d489 100644 --- a/dpnp/dpnp_iface_nanfunctions.py +++ b/dpnp/dpnp_iface_nanfunctions.py @@ -60,6 +60,36 @@ ] +def _replace_nan_no_mask(a, val): + """ + Replace NaNs in array `a` with `val`. + + If `a` is of inexact type, make a copy of `a`, replace NaNs with + the `val` value, and return the copy. If `a` is not of inexact type, + do nothing and return `a`. + + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array. + val : float + NaN values are set to `val` before doing the operation. + + Returns + ------- + out : dpnp.ndarray + If `a` is of inexact type, return a copy of `a` with the NaNs + replaced by the fill value, otherwise return `a`. + + """ + + dpnp.check_supported_arrays_type(a) + if dpnp.issubdtype(a.dtype, dpnp.inexact): + return dpnp.nan_to_num(a, nan=val, posinf=dpnp.inf, neginf=-dpnp.inf) + + return a + + def _replace_nan(a, val): """ Replace NaNs in array `a` with `val`. @@ -107,6 +137,18 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False): For full documentation refer to :obj:`numpy.nanargmax`. + Warning + ------- + This function synchronizes in order to test for all-NaN slices in the array. + This may harm performance in some applications. To avoid synchronization, + the user is recommended to filter NaNs themselves and use `dpnp.argmax` + on the filtered array. + + Warning + ------- + The results cannot be trusted if a slice contains only NaNs + and -Infs. + Parameters ---------- a : {dpnp.ndarray, usm_ndarray} @@ -136,8 +178,6 @@ def nanargmax(a, axis=None, out=None, *, keepdims=False): values ignoring NaNs. The returned array must have the default array index data type. For all-NaN slices ``ValueError`` is raised. - Warning: the results cannot be trusted if a slice contains only NaNs - and -Infs. Limitations ----------- @@ -181,6 +221,18 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False): For full documentation refer to :obj:`numpy.nanargmin`. + Warning + ------- + This function synchronizes in order to test for all-NaN slices in the array. + This may harm performance in some applications. To avoid synchronization, + the user is recommended to filter NaNs themselves and use `dpnp.argmax` + on the filtered array. + + Warning + ------- + The results cannot be trusted if a slice contains only NaNs + and -Infs. + Parameters ---------- a : {dpnp.ndarray, usm_ndarray} @@ -210,8 +262,6 @@ def nanargmin(a, axis=None, out=None, *, keepdims=False): values ignoring NaNs. The returned array must have the default array index data type. For all-NaN slices ``ValueError`` is raised. - Warning: the results cannot be trusted if a slice contains only NaNs - and Infs. Limitations ----------- @@ -315,7 +365,7 @@ def nancumprod(a, axis=None, dtype=None, out=None): """ - a, _ = _replace_nan(a, 1) + a = _replace_nan_no_mask(a, 1.0) return dpnp.cumprod(a, axis=axis, dtype=dtype, out=out) @@ -385,7 +435,7 @@ def nancumsum(a, axis=None, dtype=None, out=None): """ - a, _ = _replace_nan(a, 0) + a = _replace_nan_no_mask(a, 0.0) return dpnp.cumsum(a, axis=axis, dtype=dtype, out=out) @@ -884,7 +934,7 @@ def nanprod( """ - a, _ = _replace_nan(a, 1) + a = _replace_nan_no_mask(a, 1.0) return dpnp.prod( a, axis=axis, @@ -988,7 +1038,7 @@ def nansum( """ - a, _ = _replace_nan(a, 0) + a = _replace_nan_no_mask(a, 0.0) return dpnp.sum( a, axis=axis,