Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions dpnp/dpnp_utils/dpnp_utils_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def _calc_median(a, axis, out=None):
return res


def _calc_nanmedian(a, axis, out=None):
def _calc_nanmedian(a, out=None):
"""Compute the median of an array along a specified axis, ignoring NaNs."""
mask = dpnp.isnan(a)
valid_counts = dpnp.sum(~mask, axis=axis)
valid_counts = dpnp.sum(~mask, axis=-1)
if out is None:
res = dpnp.empty_like(valid_counts, dtype=a.dtype)
else:
Expand All @@ -76,27 +76,19 @@ def _calc_nanmedian(a, axis, out=None):
)
res = out

# Iterate over all indices of the output shape
for idx in dpnp.ndindex(res.shape):
current_valid_counts = valid_counts[idx]
left = (valid_counts - 1) // 2
right = valid_counts // 2

if current_valid_counts > 0:
# Extract the corresponding slice from the last axis of `a`
data = a[idx][:current_valid_counts]
left = (current_valid_counts - 1) // 2
right = current_valid_counts // 2
left_data = dpnp.take_along_axis(a, left[..., None], axis=-1)
right_data = dpnp.take_along_axis(a, right[..., None], axis=-1)
res = dpnp.where(
valid_counts[..., None] > 0, (left_data + right_data) / 2.0, dpnp.nan
)

if left == right:
res[idx] = data[left]
else:
res[idx] = (data[left] + data[right]) / 2.0
else:
warnings.warn(
"All-NaN slice encountered", RuntimeWarning, stacklevel=6
)
res[idx] = dpnp.nan
if mask.all(axis=-1).any():
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=6)

return res
return dpnp.squeeze(res)


def _flatten_array_along_axes(a, axes_to_flatten, overwrite_input):
Expand Down Expand Up @@ -232,7 +224,8 @@ def dpnp_median(

if ignore_nan:
# sorting puts NaNs at the end
res = _calc_nanmedian(a_sorted, axis=axis, out=out)
assert axis == -1
res = _calc_nanmedian(a_sorted, out=out)
else:
# We can't pass keepdims and use it in dpnp.mean and dpnp.any
# because of the reshape hack that might have been used in
Expand Down
Loading