@@ -1932,6 +1932,87 @@ array mean(
19321932 return mean (a, std::vector<int >{axis}, keepdims, to_stream (s));
19331933}
19341934
1935+ array median (const array& a, bool keepdims, StreamOrDevice s /* = {}*/ ) {
1936+ std::vector<int > axes (a.ndim ());
1937+ std::iota (axes.begin (), axes.end (), 0 );
1938+ return median (a, axes, keepdims, to_stream (s));
1939+ }
1940+
1941+ array median (
1942+ const array& a,
1943+ const std::vector<int >& axes,
1944+ bool keepdims /* = false */ ,
1945+ StreamOrDevice s /* = {}*/ ) {
1946+ int ndim = a.ndim ();
1947+ std::set<int > set_axes;
1948+ for (int axis : axes) {
1949+ if (axis < -ndim || axis >= ndim) {
1950+ std::ostringstream msg;
1951+ msg << " [median] axis " << axis << " is out of bounds for array with "
1952+ << ndim << " dimensions." ;
1953+ throw std::invalid_argument (msg.str ());
1954+ }
1955+ set_axes.insert (axis < 0 ? axis + ndim : axis);
1956+ }
1957+ if (set_axes.size () != axes.size ()) {
1958+ throw std::invalid_argument (" [median] Received duplicate axis." );
1959+ }
1960+ std::vector<int > sorted_axes (set_axes.begin (), set_axes.end ());
1961+ auto dtype = at_least_float (a.dtype ());
1962+ std::vector<int > transpose_axes;
1963+ for (int i = 0 , j = 0 ; i < a.ndim (); ++i) {
1964+ if (j < sorted_axes.size () && i == sorted_axes[j]) {
1965+ j++;
1966+ continue ;
1967+ }
1968+ transpose_axes.push_back (i);
1969+ }
1970+ int flat_start = transpose_axes.size ();
1971+ transpose_axes.insert (
1972+ transpose_axes.end (), sorted_axes.begin (), sorted_axes.end ());
1973+
1974+ // Move all the median axes to the back and flatten
1975+ auto flat_a =
1976+ flatten (transpose (a, transpose_axes, s), flat_start, a.ndim (), s);
1977+ int flat_size = flat_a.shape (-1 );
1978+ if (flat_size == 0 ) {
1979+ throw std::invalid_argument (
1980+ " [median] Cannot take median along empty axis." );
1981+ }
1982+
1983+ // Sort the last axis
1984+ auto sorted_a = sort (flat_a, -1 , s);
1985+
1986+ // Take the midpoint
1987+ auto mp = flat_size / 2 ;
1988+ auto start = Shape (sorted_a.ndim (), 0 );
1989+ auto stop = sorted_a.shape ();
1990+ start.back () = mp;
1991+ stop.back () = mp + 1 ;
1992+ auto median_a = astype (slice (sorted_a, start, stop, s), dtype, s);
1993+ if (flat_size % 2 == 0 ) {
1994+ start.back () = mp - 1 ;
1995+ stop.back () = mp;
1996+ median_a = multiply (
1997+ add (median_a, astype (slice (sorted_a, start, stop, s), dtype, s), s),
1998+ array (0.5 , dtype),
1999+ s);
2000+ }
2001+ median_a = squeeze (median_a, -1 , s);
2002+ if (keepdims) {
2003+ median_a = expand_dims (median_a, sorted_axes, s);
2004+ }
2005+ return median_a;
2006+ }
2007+
2008+ array median (
2009+ const array& a,
2010+ int axis,
2011+ bool keepdims /* = false */ ,
2012+ StreamOrDevice s /* = {} */ ) {
2013+ return median (a, std::vector<int >{axis}, keepdims, to_stream (s));
2014+ }
2015+
19352016array var (
19362017 const array& a,
19372018 bool keepdims,
0 commit comments