Skip to content

Commit fb6244b

Browse files
committed
Add MLX where.
1 parent 0ccd6bc commit fb6244b

File tree

2 files changed

+1
-10
lines changed

2 files changed

+1
-10
lines changed

keras/backend/mlx/numpy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -889,9 +889,7 @@ def vstack(xs):
889889

890890

891891
def where(condition, x1, x2):
892-
# TODO: Trivial to implement with masking but it would be incorrect for
893-
# instance in the presence of nans or infs
894-
raise NotImplementedError("The MLX backend doesn't support where yet")
892+
return mx.where(condition, x1, x2)
895893

896894

897895
def divide(x1, x2):

keras/backend/mlx/rnn.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import mlx.core as mx
2-
3-
41
def rnn(
52
step_function,
63
inputs,
@@ -27,7 +24,3 @@ def lstm(*args, **kwargs):
2724

2825
def gru(*args, **kwargs):
2926
raise NotImplementedError("gru not yet implemented in mlx")
30-
31-
32-
def unstack(x, axis=0):
33-
return mx.split(x, axis=axis)

0 commit comments

Comments
 (0)