Skip to content
291 changes: 203 additions & 88 deletions scripts/statistics/median.m
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
szx = sz_out = size (x);
ndx = ndims (x);
outtype = class (x);
xsparse = issparse (x);

if (nvarg > 1 && ! varg_chars(2:end))
## Only first varargin can be numeric
Expand Down Expand Up @@ -266,7 +267,7 @@
switch (outtype)
case {"double", "single"}
m = NaN (sz_out, outtype);
if (issparse (x))
if (xsparse)
m = sparse (m);
endif
case ("logical")
Expand All @@ -280,7 +281,7 @@
if (all (isnan (x(:))))
## all NaN input, output single or double NaNs in pre-determined size
m = NaN (sz_out, outtype);
if (issparse (x))
if (xsparse)
m = sparse (m);
endif
return;
Expand All @@ -296,7 +297,7 @@
return;
endif

## Permute dim to simplify all operations along dim1. At func. end ipermute.
## Permute dim to simplify all operations along dim1
if (numel (dim) > 1 || (dim != 1 && ! isvector (x)))
perm = 1 : ndx;

Expand Down Expand Up @@ -333,124 +334,207 @@
omitnan = false;
endif

x = sort (x, dim); # Note: pushes any NaN's to end for omitnan compatibility
## Sparse inputs use the original sort-based code path to preserve sparsity.
## Dense inputs use nth_element for O(n) selection instead of O(n log n) sort.

if (omitnan)
## Ignore any NaN's in data. Each operating vector might have a
## different number of non-NaN data points.
if (xsparse)
# use sort for sparse matrices to retain sparsity of output
x = sort (x, dim);

if (isvector (x))
## Checks above ensure either dim1 or dim2 vector
x = x(! isnan (x));
n = numel (x);
k = floor ((n + 1) / 2);
if (mod (n, 2))
## odd
m = x(k);
if (omitnan)
if (isvector (x))
x = x(! isnan (x));
n = numel (x);
k = floor ((n + 1) / 2);
if (n == 0)
m = sparse (NaN);
elseif (mod (n, 2))
m = sparse (x(k));
else
m = sparse ((x(k) + x(k + 1)) / 2);
endif
else
## even
m = (x(k) + x(k + 1)) / 2;
n = sum (! isnan (x), 1)(:);
k = floor ((n + 1) / 2);
odd_cols = mod (n, 2) & n;
even_cols = ! odd_cols & n;

m = sparse (NaN ([1, szx(2 : end)]));

if (ndims (x) > 2)
szx_flat = [szx(1), prod(szx(2 : end))];
else
szx_flat = szx;
endif

if (any (odd_cols))
idx = sub2ind (szx_flat, k(odd_cols), find (odd_cols));
m(odd_cols) = x(idx);
endif
if (any (even_cols))
k_even = k(even_cols);
idx = sub2ind (szx_flat, [k_even, k_even + 1], ...
find (even_cols)(:, [1, 1]));
m(even_cols) = sum (x(idx), 2) / 2;
endif
endif

else
## Each column may have a different n and k. Force index column vector
## for consistent orientation for 2-D and N-D inputs, then use sub2ind to
## get correct element(s) for each column.
## No "omitnan" for sparse
if (all (! nanfree))
m = NaN (sz_out);
m = sparse (m);

n = sum (! isnan (x), 1)(:);
k = floor ((n + 1) / 2);
m_idx_odd = mod (n, 2) & n;
m_idx_even = (! m_idx_odd) & n;
else
if (isvector (x))
n = numel (x);
k = floor ((n + 1) / 2);

m = x(k);
if (! mod (n, 2))
if (any (isinf ([x(k), x(k+1)])))
m = x(k) + x(k+1);
else
m += (x(k + 1) - m) / 2;
endif
endif
m = sparse (m);

m = NaN ([1, szx(2 : end)]);
if (issparse (x))
m = sparse (m);
endif
else
n = szx(1);
k = floor ((n + 1) / 2);

if (ndims (x) > 2)
szx = [szx(1), prod(szx(2 : end))];
endif
m = sparse (NaN ([1, szx(2 : end)]));

## Grab kth value, k possibly different for each column
if (any (m_idx_odd))
x_idx_odd = sub2ind (szx, k(m_idx_odd), find (m_idx_odd));
m(m_idx_odd) = x(x_idx_odd);
endif
if (any (m_idx_even))
k_even = k(m_idx_even);
x_idx_even = sub2ind (szx, [k_even, k_even + 1], ...
(find (m_idx_even))(:, [1, 1]));
m(m_idx_even) = sum (x(x_idx_even), 2) / 2;
if (! mod (n, 2))
m(nanfree) = (x(k, nanfree) + x(k + 1, nanfree)) / 2;
else
m(nanfree) = x(k, nanfree);
endif
endif
endif
endif

else
## No "omitnan". All 'vectors' uniform length.
## All types without a NaN value will use this path.
if (all (! nanfree))
m = NaN (sz_out);
if (issparse (x))
m = sparse (m);
endif

else
## dense: use nth_element for O(n) selection
if (omitnan)
## Ignore any NaN's in data.
## Each operating vector might have a different number of non-NaN data points.
if (isvector (x))
## Checks above ensure either dim1 or dim2 vector
x = x(! isnan (x));
n = numel (x);
k = floor ((n + 1) / 2);

m = x(k);
if (! mod (n, 2))
## Even
if (any (isinf ([x(k), x(k+1)])))
## If either center value is Inf, replace m by +/-Inf or NaN.
m = x(k) + x(k+1);
elseif (any (isa (x, "integer")))
## avoid int overflow issues
m2 = x(k + 1);
if (sign (m) != sign (m2))
m += m2;
m /= 2;
else
m += (m2 - m) / 2;
endif
if (n == 0)
m = NaN (sz_out, outtype);
else
k = floor ((n + 1) / 2);
if (mod (n, 2))
m = nth_element (x, k);
else
m += (x(k + 1) - m) / 2;
vals = nth_element (x, [k, k + 1]);
m = mid_two_vals (vals(1), vals(2), isa (x, "integer"));
endif
endif

else
## Nonvector, all operations were permuted to be along dim 1
## Columns may have different non-NaN counts; process individually.
n = szx(1);
k = floor ((n + 1) / 2);
rest_sz = szx(2 : end);
ncols = prod (rest_sz);

if (ndims (x) > 2)
x = reshape (x, [n, ncols]);
endif

if (isfloat (x))
m = NaN ([1, szx(2 : end)]);
if (issparse (x))
m = sparse (m);
endif
m = NaN (1, ncols);
else
m = zeros ([1, szx(2 : end)], outtype);
m = zeros (1, ncols, outtype);
endif

if (! mod (n, 2))
## Even
if (any (isa (x, "integer")))
## avoid int overflow issues
x = reshape (x, [n, ncols]);

## Use flattened index to simplify N-D operations
m(1, :) = x(k, :);
m2 = x(k + 1, :);
for j = 1:ncols
col = x(:, j);
col = col(! isnan (col));
ncol = numel (col);

samesign = prod (sign ([m(1, :); m2]), 1) == 1;
m(1, :) = samesign .* m(1, :) + ...
(m2 + !samesign .* m(1, :) - samesign .* m(1, :)) / 2;
if (ncol == 0)
continue;
endif

k = floor ((ncol + 1) / 2);
if (mod (ncol, 2))
m(j) = nth_element (col, k);
else
m(nanfree) = (x(k, nanfree) + x(k + 1, nanfree)) / 2;
vals = nth_element (col, [k, k + 1]);
m(j) = mid_two_vals (vals(1), vals(2), isa (x, "integer"));
endif
endfor

if (numel (rest_sz) > 1)
m = reshape (m, [1, rest_sz]);
endif
endif

else
## No "omitnan". All types without a NaN value will use this path.
if (all (! nanfree))
m = NaN (sz_out);

else
if (isvector (x))
n = numel (x);
k = floor ((n + 1) / 2);

if (! nanfree)
m = NaN (sz_out);
else
if (mod (n, 2))
## Odd
m = nth_element (x, k);
else
## Even
vals = nth_element (x, [k, k + 1]);
m = mid_two_vals (vals(1), vals(2), isa (x, "integer"));
endif
endif

else
## Odd. Use flattened index to simplify N-D operations
m(nanfree) = x(k, nanfree);
## Nonvector, all operations were permuted to be along dim 1
n = szx(1);
k = floor ((n + 1) / 2);
rest_sz = szx(2 : end);
ncols = prod (rest_sz);

if (isfloat (x))
m = NaN (1, ncols);
else
m = zeros (1, ncols, outtype);
endif

if (ndims (x) > 2)
x = reshape (x, [n, ncols]);
nanfree = reshape (nanfree, [1, ncols]);
endif

if (mod (n, 2))
## Odd. Use flattened index to simplify N-D operations
if (any (nanfree(:)))
vals = nth_element (x(:, nanfree), k, 1);
m(nanfree) = vals;
endif
else
## Even
if (any (nanfree(:)))
vals = nth_element (x(:, nanfree), [k, k + 1], 1);
m(nanfree) = mid_two_vals (vals(1, :), vals(2, :), isa (x, "integer"));
endif
endif

if (numel (rest_sz) > 1)
m = reshape (m, [1, rest_sz]);
endif
endif
endif
endif
Expand All @@ -469,6 +553,37 @@
endfunction


## Compute the element-wise midpoint of the two middle values M1 and M2.
##
## M1 and M2 must have the same size (scalars, or one value per column).
## IS_INT is true when the data are of an integer class.
##
## Integer inputs: the midpoint is formed without leaving the integer class.
## When both values have the same sign, M1 + (M2 - M1) / 2 is used so the
## intermediate result cannot saturate; when the signs differ (or one value
## is zero) the plain sum cannot saturate, so (M1 + M2) / 2 is used.
##
## Floating-point inputs: (M1 + M2) / 2 is used, which yields +/-Inf or NaN
## as appropriate when either value is infinite.  If both values are finite
## but their sum overflows to +/-Inf, the midpoint is recomputed for those
## elements as M1 / 2 + M2 / 2 so that a finite result is returned.
function m = mid_two_vals (m1, m2, is_int)
  if (is_int)
    samesign = sign (m1) == sign (m2);
    m = zeros (size (m1), "like", m1);
    ## Same sign: add half the difference to avoid integer saturation.
    m(samesign) = m1(samesign) + (m2(samesign) - m1(samesign)) / 2;
    ## Opposite sign or zero: the sum always fits in the integer range.
    m(! samesign) = (m1(! samesign) + m2(! samesign)) / 2;
  else
    m = (m1 + m2) / 2;
    ## A sum of two finite values that overflowed must be recomputed; halving
    ## each operand first keeps the intermediate values in range.
    ovf = isinf (m) & isfinite (m1) & isfinite (m2);
    if (any (ovf(:)))
      m(ovf) = m1(ovf) / 2 + m2(ovf) / 2;
    endif
  endif
endfunction


## Tests for per-element Inf handling and integer midpoints in even-length median
%!test
%! x = [-Inf, 2; 1, 3];
%! assert (median (x, 1), [-Inf, 2.5]);

%!test
%! x = [-Inf, Inf; Inf, -Inf];
%! assert (median (x, 1), [NaN, NaN]);

%!test
%! x = [1, Inf, 3; 5, 7, 9];
%! assert (median (x, 1), [3, Inf, 6]);

%!test
%! x = int64 ([10, 20; 30, 40]);
%! assert (median (x, 1), int64 ([20, 30]));


%!assert (median (1), 1)
%!assert (median ([1, 2, 3]), 2)
%!assert (median ([1, 2, 3]'), 2)
Expand Down