diff --git a/scripts/statistics/median.m b/scripts/statistics/median.m index e2094e3724..fe5944d7af 100644 --- a/scripts/statistics/median.m +++ b/scripts/statistics/median.m @@ -131,6 +131,7 @@ szx = sz_out = size (x); ndx = ndims (x); outtype = class (x); + xsparse = issparse (x); if (nvarg > 1 && ! varg_chars(2:end)) ## Only first varargin can be numeric @@ -266,7 +267,7 @@ switch (outtype) case {"double", "single"} m = NaN (sz_out, outtype); - if (issparse (x)) + if (xsparse) m = sparse (m); endif case ("logical") @@ -280,7 +281,7 @@ if (all (isnan (x(:)))) ## all NaN input, output single or double NaNs in pre-determined size m = NaN (sz_out, outtype); - if (issparse (x)) + if (xsparse) m = sparse (m); endif return; @@ -296,7 +297,7 @@ return; endif - ## Permute dim to simplify all operations along dim1. At func. end ipermute. + ## Permute dim to simplify all operations along dim1 if (numel (dim) > 1 || (dim != 1 && ! isvector (x))) perm = 1 : ndx; @@ -333,124 +334,207 @@ omitnan = false; endif - x = sort (x, dim); # Note: pushes any NaN's to end for omitnan compatibility + ## Sparse inputs use the original sort-based code path to preserve sparsity. + ## Dense inputs use nth_element for O(n) selection instead of O(n log n) sort. - if (omitnan) - ## Ignore any NaN's in data. Each operating vector might have a - ## different number of non-NaN data points. + if (xsparse) + # use sort for sparse matrices to retain sparsity of output + x = sort (x, dim); - if (isvector (x)) - ## Checks above ensure either dim1 or dim2 vector - x = x(! isnan (x)); - n = numel (x); - k = floor ((n + 1) / 2); - if (mod (n, 2)) - ## odd - m = x(k); + if (omitnan) + if (isvector (x)) + x = x(! isnan (x)); + n = numel (x); + k = floor ((n + 1) / 2); + if (n == 0) + m = sparse (NaN); + elseif (mod (n, 2)) + m = sparse (x(k)); + else + m = sparse ((x(k) + x(k + 1)) / 2); + endif else - ## even - m = (x(k) + x(k + 1)) / 2; + n = sum (! 
isnan (x), 1)(:); + k = floor ((n + 1) / 2); + odd_cols = mod (n, 2) & n; + even_cols = ! odd_cols & n; + + m = sparse (NaN ([1, szx(2 : end)])); + + if (ndims (x) > 2) + szx_flat = [szx(1), prod(szx(2 : end))]; + else + szx_flat = szx; + endif + + if (any (odd_cols)) + idx = sub2ind (szx_flat, k(odd_cols), find (odd_cols)); + m(odd_cols) = x(idx); + endif + if (any (even_cols)) + k_even = k(even_cols); + idx = sub2ind (szx_flat, [k_even, k_even + 1], ... + find (even_cols)(:, [1, 1])); + m(even_cols) = sum (x(idx), 2) / 2; + endif endif else - ## Each column may have a different n and k. Force index column vector - ## for consistent orientation for 2-D and N-D inputs, then use sub2ind to - ## get correct element(s) for each column. + ## No "omitnan" for sparse + if (all (! nanfree)) + m = NaN (sz_out); + m = sparse (m); - n = sum (! isnan (x), 1)(:); - k = floor ((n + 1) / 2); - m_idx_odd = mod (n, 2) & n; - m_idx_even = (! m_idx_odd) & n; + else + if (isvector (x)) + n = numel (x); + k = floor ((n + 1) / 2); + + m = x(k); + if (! mod (n, 2)) + if (any (isinf ([x(k), x(k+1)]))) + m = x(k) + x(k+1); + else + m += (x(k + 1) - m) / 2; + endif + endif + m = sparse (m); - m = NaN ([1, szx(2 : end)]); - if (issparse (x)) - m = sparse (m); - endif + else + n = szx(1); + k = floor ((n + 1) / 2); - if (ndims (x) > 2) - szx = [szx(1), prod(szx(2 : end))]; - endif + m = sparse (NaN ([1, szx(2 : end)])); - ## Grab kth value, k possibly different for each column - if (any (m_idx_odd)) - x_idx_odd = sub2ind (szx, k(m_idx_odd), find (m_idx_odd)); - m(m_idx_odd) = x(x_idx_odd); - endif - if (any (m_idx_even)) - k_even = k(m_idx_even); - x_idx_even = sub2ind (szx, [k_even, k_even + 1], ... - (find (m_idx_even))(:, [1, 1])); - m(m_idx_even) = sum (x(x_idx_even), 2) / 2; + if (! mod (n, 2)) + m(nanfree) = (x(k, nanfree) + x(k + 1, nanfree)) / 2; + else + m(nanfree) = x(k, nanfree); + endif + endif endif endif else - ## No "omitnan". All 'vectors' uniform length. 
- ## All types without a NaN value will use this path. - if (all (! nanfree)) - m = NaN (sz_out); - if (issparse (x)) - m = sparse (m); - endif - - else + ## dense: use nth_element for O(n) selection + if (omitnan) + ## Ignore any NaN's in data. + ## Each operating vector might have a different number of non-NaN data points. if (isvector (x)) + ## Checks above ensure either dim1 or dim2 vector + x = x(! isnan (x)); n = numel (x); - k = floor ((n + 1) / 2); - - m = x(k); - if (! mod (n, 2)) - ## Even - if (any (isinf ([x(k), x(k+1)]))) - ## If either center value is Inf, replace m by +/-Inf or NaN. - m = x(k) + x(k+1); - elseif (any (isa (x, "integer"))) - ## avoid int overflow issues - m2 = x(k + 1); - if (sign (m) != sign (m2)) - m += m2; - m /= 2; - else - m += (m2 - m) / 2; - endif + if (n == 0) + m = NaN (sz_out, outtype); + else + k = floor ((n + 1) / 2); + if (mod (n, 2)) + m = nth_element (x, k); else - m += (x(k + 1) - m) / 2; + vals = nth_element (x, [k, k + 1]); + m = mid_two_vals (vals(1), vals(2), isa (x, "integer")); endif endif else - ## Nonvector, all operations were permuted to be along dim 1 + ## Columns may have different non-NaN counts; process individually. n = szx(1); - k = floor ((n + 1) / 2); + rest_sz = szx(2 : end); + ncols = prod (rest_sz); + + if (ndims (x) > 2) + x = reshape (x, [n, ncols]); + endif if (isfloat (x)) - m = NaN ([1, szx(2 : end)]); - if (issparse (x)) - m = sparse (m); - endif + m = NaN (1, ncols); else - m = zeros ([1, szx(2 : end)], outtype); + m = zeros (1, ncols, outtype); endif - if (! mod (n, 2)) - ## Even - if (any (isa (x, "integer"))) - ## avoid int overflow issues + x = reshape (x, [n, ncols]); - ## Use flattened index to simplify N-D operations - m(1, :) = x(k, :); - m2 = x(k + 1, :); + for j = 1:ncols + col = x(:, j); + col = col(! isnan (col)); + ncol = numel (col); - samesign = prod (sign ([m(1, :); m2]), 1) == 1; - m(1, :) = samesign .* m(1, :) + ... 
- (m2 + !samesign .* m(1, :) - samesign .* m(1, :)) / 2; + if (ncol == 0) + continue; + endif + k = floor ((ncol + 1) / 2); + if (mod (ncol, 2)) + m(j) = nth_element (col, k); else - m(nanfree) = (x(k, nanfree) + x(k + 1, nanfree)) / 2; + vals = nth_element (col, [k, k + 1]); + m(j) = mid_two_vals (vals(1), vals(2), isa (x, "integer")); + endif + endfor + + if (numel (rest_sz) > 1) + m = reshape (m, [1, rest_sz]); + endif + endif + + else + ## No "omitnan". All types without a NaN value will use this path. + if (all (! nanfree)) + m = NaN (sz_out); + + else + if (isvector (x)) + n = numel (x); + k = floor ((n + 1) / 2); + + if (! nanfree) + m = NaN (sz_out); + else + if (mod (n, 2)) + ## Odd + m = nth_element (x, k); + else + ## Even + vals = nth_element (x, [k, k + 1]); + m = mid_two_vals (vals(1), vals(2), isa (x, "integer")); + endif endif + else - ## Odd. Use flattened index to simplify N-D operations - m(nanfree) = x(k, nanfree); + ## Nonvector, all operations were permuted to be along dim 1 + n = szx(1); + k = floor ((n + 1) / 2); + rest_sz = szx(2 : end); + ncols = prod (rest_sz); + + if (isfloat (x)) + m = NaN (1, ncols); + else + m = zeros (1, ncols, outtype); + endif + + if (ndims (x) > 2) + x = reshape (x, [n, ncols]); + nanfree = reshape (nanfree, [1, ncols]); + endif + + if (mod (n, 2)) + ## Odd. Use flattened index to simplify N-D operations + if (any (nanfree(:))) + vals = nth_element (x(:, nanfree), k, 1); + m(nanfree) = vals; + endif + else + ## Even + if (any (nanfree(:))) + vals = nth_element (x(:, nanfree), [k, k + 1], 1); + m(nanfree) = mid_two_vals (vals(1, :), vals(2, :), isa (x, "integer")); + endif + endif + + if (numel (rest_sz) > 1) + m = reshape (m, [1, rest_sz]); + endif endif endif endif @@ -469,6 +553,37 @@ endfunction +## Compute mean of two middle values, handling Inf and integer overflow. 
+function m = mid_two_vals (m1, m2, is_int)
+  ## Return the element-wise midpoint of M1 and M2.
+  ##
+  ## M1 and M2 must have the same size (scalars, or row vectors holding the
+  ## lower and upper middle value of each column); M has that same size.
+  ## IS_INT selects the integer-safe branch; callers pass
+  ## isa (x, "integer") for the data being reduced.
+
+  if (is_int)
+    ## Octave integer arithmetic saturates instead of wrapping, so never form
+    ## an intermediate value that can leave the range of the type:
+    ##   same sign      -> the difference is in range, use m1 + (m2 - m1) / 2
+    ##   different sign -> the sum is in range, use (m1 + m2) / 2
+    samesign = sign (m1) == sign (m2);
+    m = zeros (size (m1), "like", m1);
+    m(samesign) = m1(samesign) + (m2(samesign) - m1(samesign)) / 2;
+    m(! samesign) = (m1(! samesign) + m2(! samesign)) / 2;
+  else
+    ## The plain mean gives the desired Inf/NaN results without special cases:
+    ## (Inf + finite) / 2 = Inf, (Inf + Inf) / 2 = Inf, (-Inf + Inf) / 2 = NaN.
+    m = (m1 + m2) / 2;
+
+    ## The sum of two large *finite* values can overflow to +/-Inf
+    ## (e.g., realmax + realmax).  Overflow is only possible when both values
+    ## have the same sign, in which case their difference cannot overflow, so
+    ## recompute just those entries with the overflow-safe form.  All other
+    ## entries keep the result of the plain mean unchanged.
+    ovf = isinf (m) & isfinite (m1) & isfinite (m2);
+    if (any (ovf(:)))
+      m(ovf) = m1(ovf) + (m2(ovf) - m1(ovf)) / 2;
+    endif
+  endif
+endfunction
+
+
+## Tests for per-element Inf handling in even-length median
+%!test
+%! x = [-Inf, 2; 1, 3];
+%! assert (median (x, 1), [-Inf, 2.5]);
+
+%!test
+%! x = [-Inf, Inf; Inf, -Inf];
+%! assert (median (x, 1), [NaN, NaN]);
+
+%!test
+%! x = [1, Inf, 3; 5, 7, 9];
+%! assert (median (x, 1), [3, Inf, 6]);
+
+%!test
+%! x = int64 ([10, 20; 30, 40]);
+%! assert (median (x, 1), int64 ([20, 30]));
+
+## Tests that the mean of two large finite middle values does not overflow
+%!assert (median ([realmax, realmax]), realmax)
+%!assert (median ([-realmax, -realmax]), -realmax)
+%!assert (median ([realmax, realmax; realmax, realmax], 1), [realmax, realmax])
+%!assert (median (realmax ("single") * ones (1, 2, "single")), realmax ("single"))
+
+
 %!assert (median (1), 1)
 %!assert (median ([1, 2, 3]), 2)
 %!assert (median ([1, 2, 3]'), 2)