We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
median(...)
1 parent 587b4c9 commit 1edc652Copy full SHA for 1edc652
keras/src/backend/mlx/numpy.py
@@ -579,9 +579,16 @@ def maximum(x1, x2):
579
return mx.maximum(x1, x2)
580
581
582
-def median(x, axis=None, keepdims=False):
583
- # TODO: Maybe implement via sort?
584
- raise NotImplementedError("The MLX backend doesn't support median yet")
+def median(x, axis=-1, keepdims=False):
+ x = convert_to_tensor(x)
+ x_sorted = mx.sort(x, axis=axis)
585
+ axis_size = x_sorted.shape[axis]
586
+ medians = mx.take(
587
+ x_sorted, indices=mx.array([(axis_size // 2) - 1]), axis=axis
588
+ )
589
+ if not keepdims:
590
+ medians = mx.squeeze(medians, axis=axis)
591
+ return medians
592
593
594
def meshgrid(*x, indexing="xy"):
0 commit comments