Skip to content

Commit 539d832

Browse files
authored
add median op (#2705)
1 parent c4767d1 commit 539d832

File tree

5 files changed

+164
-0
lines changed

5 files changed

+164
-0
lines changed

docs/src/python/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ Operations
112112
max
113113
maximum
114114
mean
115+
median
115116
meshgrid
116117
min
117118
minimum

mlx/ops.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,87 @@ array mean(
19321932
return mean(a, std::vector<int>{axis}, keepdims, to_stream(s));
19331933
}
19341934

1935+
array median(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
1936+
std::vector<int> axes(a.ndim());
1937+
std::iota(axes.begin(), axes.end(), 0);
1938+
return median(a, axes, keepdims, to_stream(s));
1939+
}
1940+
1941+
array median(
1942+
const array& a,
1943+
const std::vector<int>& axes,
1944+
bool keepdims /* = false */,
1945+
StreamOrDevice s /* = {}*/) {
1946+
int ndim = a.ndim();
1947+
std::set<int> set_axes;
1948+
for (int axis : axes) {
1949+
if (axis < -ndim || axis >= ndim) {
1950+
std::ostringstream msg;
1951+
msg << "[median] axis " << axis << " is out of bounds for array with "
1952+
<< ndim << " dimensions.";
1953+
throw std::invalid_argument(msg.str());
1954+
}
1955+
set_axes.insert(axis < 0 ? axis + ndim : axis);
1956+
}
1957+
if (set_axes.size() != axes.size()) {
1958+
throw std::invalid_argument("[median] Received duplicate axis.");
1959+
}
1960+
std::vector<int> sorted_axes(set_axes.begin(), set_axes.end());
1961+
auto dtype = at_least_float(a.dtype());
1962+
std::vector<int> transpose_axes;
1963+
for (int i = 0, j = 0; i < a.ndim(); ++i) {
1964+
if (j < sorted_axes.size() && i == sorted_axes[j]) {
1965+
j++;
1966+
continue;
1967+
}
1968+
transpose_axes.push_back(i);
1969+
}
1970+
int flat_start = transpose_axes.size();
1971+
transpose_axes.insert(
1972+
transpose_axes.end(), sorted_axes.begin(), sorted_axes.end());
1973+
1974+
// Move all the median axes to the back and flatten
1975+
auto flat_a =
1976+
flatten(transpose(a, transpose_axes, s), flat_start, a.ndim(), s);
1977+
int flat_size = flat_a.shape(-1);
1978+
if (flat_size == 0) {
1979+
throw std::invalid_argument(
1980+
"[median] Cannot take median along empty axis.");
1981+
}
1982+
1983+
// Sort the last axis
1984+
auto sorted_a = sort(flat_a, -1, s);
1985+
1986+
// Take the midpoint
1987+
auto mp = flat_size / 2;
1988+
auto start = Shape(sorted_a.ndim(), 0);
1989+
auto stop = sorted_a.shape();
1990+
start.back() = mp;
1991+
stop.back() = mp + 1;
1992+
auto median_a = astype(slice(sorted_a, start, stop, s), dtype, s);
1993+
if (flat_size % 2 == 0) {
1994+
start.back() = mp - 1;
1995+
stop.back() = mp;
1996+
median_a = multiply(
1997+
add(median_a, astype(slice(sorted_a, start, stop, s), dtype, s), s),
1998+
array(0.5, dtype),
1999+
s);
2000+
}
2001+
median_a = squeeze(median_a, -1, s);
2002+
if (keepdims) {
2003+
median_a = expand_dims(median_a, sorted_axes, s);
2004+
}
2005+
return median_a;
2006+
}
2007+
2008+
array median(
2009+
const array& a,
2010+
int axis,
2011+
bool keepdims /* = false */,
2012+
StreamOrDevice s /* = {} */) {
2013+
return median(a, std::vector<int>{axis}, keepdims, to_stream(s));
2014+
}
2015+
19352016
array var(
19362017
const array& a,
19372018
bool keepdims,

mlx/ops.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,26 @@ array mean(
539539
bool keepdims = false,
540540
StreamOrDevice s = {});
541541

542+
/** Computes the median of the elements of an array. */
543+
array median(const array& a, bool keepdims, StreamOrDevice s = {});
544+
inline array median(const array& a, StreamOrDevice s = {}) {
545+
return median(a, false, to_stream(s));
546+
}
547+
548+
/** Computes the median of the elements of an array along the given axes */
549+
array median(
550+
const array& a,
551+
const std::vector<int>& axes,
552+
bool keepdims = false,
553+
StreamOrDevice s = {});
554+
555+
/** Computes the median of the elements of an array along the given axis */
556+
array median(
557+
const array& a,
558+
int axis,
559+
bool keepdims = false,
560+
StreamOrDevice s = {});
561+
542562
/** Computes the variance of the elements of an array. */
543563
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
544564
inline array var(const array& a, StreamOrDevice s = {}) {

python/src/ops.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2484,6 +2484,35 @@ void init_ops(nb::module_& m) {
24842484
Returns:
24852485
array: The output array of means.
24862486
)pbdoc");
2487+
m.def(
2488+
"median",
2489+
[](const mx::array& a,
2490+
const IntOrVec& axis,
2491+
bool keepdims,
2492+
mx::StreamOrDevice s) {
2493+
return mx::median(a, get_reduce_axes(axis, a.ndim()), keepdims, s);
2494+
},
2495+
nb::arg(),
2496+
"axis"_a = nb::none(),
2497+
"keepdims"_a = false,
2498+
nb::kw_only(),
2499+
"stream"_a = nb::none(),
2500+
nb::sig(
2501+
"def median(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
2502+
R"pbdoc(
2503+
Compute the median(s) over the given axes.
2504+
2505+
Args:
2506+
a (array): Input array.
2507+
axis (int or list(int), optional): Optional axis or
2508+
axes to reduce over. If unspecified this defaults
2509+
to reducing over the entire array.
2510+
keepdims (bool, optional): Keep reduced axes as
2511+
singleton dimensions, defaults to `False`.
2512+
2513+
Returns:
2514+
array: The output array of medians.
2515+
)pbdoc");
24872516
m.def(
24882517
"var",
24892518
[](const mx::array& a,

python/tests/test_ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,39 @@ def test_mean(self):
775775
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
776776
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
777777

778+
def test_median(self):
779+
x = mx.array([])
780+
with self.assertRaises(ValueError):
781+
mx.median(x, axis=0)
782+
x = mx.array([0, 1, 2, 3, 4])
783+
with self.assertRaises(ValueError):
784+
mx.median(x, axis=(0, 1))
785+
with self.assertRaises(ValueError):
786+
mx.median(x, axis=(0, 0))
787+
788+
out = mx.median(x)
789+
self.assertEqual(out.shape, ())
790+
self.assertEqual(out.item(), 2)
791+
out = mx.median(x, keepdims=True)
792+
self.assertEqual(out.shape, (1,))
793+
794+
x = mx.array([0, 1, 2, 3, 4, 5])
795+
out = mx.median(x)
796+
self.assertEqual(out.item(), 2.5)
797+
798+
x = mx.random.normal((5, 5, 5, 5))
799+
out = mx.median(x, axis=(0, 2), keepdims=True)
800+
out_np = np.median(x, axis=(0, 2), keepdims=True)
801+
self.assertTrue(np.allclose(out, out_np))
802+
803+
out = mx.median(x, axis=(1, 3), keepdims=True)
804+
out_np = np.median(x, axis=(1, 3), keepdims=True)
805+
self.assertTrue(np.allclose(out, out_np))
806+
807+
out = mx.median(x, axis=(0, 1, 3), keepdims=True)
808+
out_np = np.median(x, axis=(0, 1, 3), keepdims=True)
809+
self.assertTrue(np.allclose(out, out_np))
810+
778811
def test_var(self):
779812
x = mx.array(
780813
[

0 commit comments

Comments
 (0)