Skip to content

Commit 1edc652

Browse files
authored
Implemented median(...) function. (#19568)
* implemented median function * uncommented convert_to_tensor() function
1 parent 587b4c9 commit 1edc652

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

keras/src/backend/mlx/numpy.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,9 +579,16 @@ def maximum(x1, x2):
579579
return mx.maximum(x1, x2)
580580

581581

582-
def median(x, axis=None, keepdims=False):
583-
# TODO: Maybe implement via sort?
584-
raise NotImplementedError("The MLX backend doesn't support median yet")
582+
def median(x, axis=-1, keepdims=False):
583+
x = convert_to_tensor(x)
584+
x_sorted = mx.sort(x, axis=axis)
585+
axis_size = x_sorted.shape[axis]
586+
medians = mx.take(
587+
x_sorted, indices=mx.array([(axis_size // 2) - 1]), axis=axis
588+
)
589+
if not keepdims:
590+
medians = mx.squeeze(medians, axis=axis)
591+
return medians
585592

586593

587594
def meshgrid(*x, indexing="xy"):

0 commit comments

Comments
 (0)