@@ -57,10 +57,10 @@ def _calc_median(a, axis, out=None):
5757 return res
5858
5959
60- def _calc_nanmedian (a , axis , out = None ):
60+ def _calc_nanmedian (a , out = None ):
6161 """Compute the median of an array along a specified axis, ignoring NaNs."""
6262 mask = dpnp .isnan (a )
63- valid_counts = dpnp .sum (~ mask , axis = axis )
63+ valid_counts = dpnp .sum (~ mask , axis = - 1 )
6464 if out is None :
6565 res = dpnp .empty_like (valid_counts , dtype = a .dtype )
6666 else :
@@ -76,27 +76,19 @@ def _calc_nanmedian(a, axis, out=None):
7676 )
7777 res = out
7878
79- # Iterate over all indices of the output shape
80- for idx in dpnp .ndindex (res .shape ):
81- current_valid_counts = valid_counts [idx ]
79+ left = (valid_counts - 1 ) // 2
80+ right = valid_counts // 2
8281
83- if current_valid_counts > 0 :
84- # Extract the corresponding slice from the last axis of `a`
85- data = a [ idx ][: current_valid_counts ]
86- left = ( current_valid_counts - 1 ) // 2
87- right = current_valid_counts // 2
82+ left_data = dpnp . take_along_axis ( a , left [..., None ], axis = - 1 )
83+ right_data = dpnp . take_along_axis ( a , right [..., None ], axis = - 1 )
84+ res = dpnp . where (
85+ valid_counts [..., None ] > 0 , ( left_data + right_data ) / 2.0 , dpnp . nan
86+ )
8887
89- if left == right :
90- res [idx ] = data [left ]
91- else :
92- res [idx ] = (data [left ] + data [right ]) / 2.0
93- else :
94- warnings .warn (
95- "All-NaN slice encountered" , RuntimeWarning , stacklevel = 6
96- )
97- res [idx ] = dpnp .nan
88+ if mask .all (axis = - 1 ).any ():
89+ warnings .warn ("All-NaN slice encountered" , RuntimeWarning , stacklevel = 6 )
9890
99- return res
91+ return dpnp . squeeze ( res )
10092
10193
10294def _flatten_array_along_axes (a , axes_to_flatten , overwrite_input ):
@@ -232,7 +224,8 @@ def dpnp_median(
232224
233225 if ignore_nan :
234226 # sorting puts NaNs at the end
235- res = _calc_nanmedian (a_sorted , axis = axis , out = out )
227+ assert axis == - 1
228+ res = _calc_nanmedian (a_sorted , out = out )
236229 else :
237230 # We can't pass keepdims and use it in dpnp.mean and dpnp.any
238231 # because of the reshape hack that might have been used in
0 commit comments