Skip to content

Commit 7237ec7

Browse files
authored
mlx - passing most layers tests (#20862)
* adding ops functions and passing most layer tests * passing most layers tests * fix for tensorflow tests, handling mlx array slicing in random crop
1 parent b9f7141 commit 7237ec7

File tree

19 files changed

+984
-358
lines changed

19 files changed

+984
-358
lines changed

keras/src/backend/mlx/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"int64": mx.int64,
3737
"bfloat16": mx.bfloat16,
3838
"bool": mx.bool_,
39+
"complex64": mx.complex64,
3940
}
4041

4142

@@ -376,9 +377,9 @@ def random_seed_dtype():
376377
return "uint32"
377378

378379

379-
def reverse_sequence(xs):
380-
indices = mx.arange(xs.shape[0] - 1, -1, -1)
381-
return mx.take(xs, indices, axis=0)
380+
def reverse_sequence(xs, axis=0):
381+
indices = mx.arange(xs.shape[axis] - 1, -1, -1)
382+
return mx.take(xs, indices, axis=axis)
382383

383384

384385
def flip(x, axis=None):

0 commit comments

Comments
 (0)