+ import builtins
+ import functools
+ import warnings
+
import mlx.core as mx
import numpy as np

from keras.src import tree
from keras.src.backend.common import KerasVariable
from keras.src.backend.common import standardize_dtype
+ from keras.src.backend.common.backend_utils import slice_along_axis
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.stateless_scope import StatelessScope
+ from keras.src.backend.common.symbolic_scope import SymbolicScope

try:
    import h5py
except ImportError:
    h5py = None

SUPPORTS_SPARSE_TENSORS = False
+ SUPPORTS_RAGGED_TENSORS = False
+ IS_THREAD_SAFE = True

MLX_DTYPES = {
    "float16": mx.float16,
@@ -67,9 +75,11 @@ def _is_h5py_dataset(obj):
    )


- def convert_to_tensor(x, dtype=None, sparse=None):
+ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
    if sparse:
        raise ValueError("`sparse=True` is not supported with mlx backend")
+     if ragged:
+         raise ValueError("`ragged=True` is not supported with mlx backend")
    mlx_dtype = to_mlx_dtype(dtype) if dtype is not None else None

    if is_tensor(x):
@@ -83,7 +93,13 @@ def convert_to_tensor(x, dtype=None, sparse=None):
        return x.value

    if isinstance(x, np.ndarray):
-         x = x.astype(standardize_dtype(x.dtype))
+         if x.dtype == np.float64:
+             # mlx backend does not support float64
+             x = x.astype(np.float32)
+         if standardize_dtype(x.dtype) == "bfloat16" and mlx_dtype is None:
+             # if a bfloat16 np.ndarray is passed to mx.array with dtype=None
+             # it casts the output to complex64, so we force cast to bfloat16
+             mlx_dtype = mx.bfloat16
        return mx.array(x, dtype=mlx_dtype)

    if isinstance(x, list):
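For reference, a quick sanity check of the two NumPy dtype special cases above; a sketch, assuming the optional `ml_dtypes` package (which registers a NumPy bfloat16 dtype) is installed:

    import ml_dtypes
    import numpy as np

    x64 = np.ones((2, 2), dtype=np.float64)
    assert convert_to_tensor(x64).dtype == mx.float32  # float64 is downcast

    xbf = np.ones((2, 2), dtype=ml_dtypes.bfloat16)
    assert convert_to_tensor(xbf).dtype == mx.bfloat16  # not complex64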
@@ -191,10 +207,12 @@ def symbolic_call(fn, args, kwargs, fill_value):
        )
        return fn(*arr_args, **arr_kwargs)

-     with StatelessScope():
+     with StatelessScope(), SymbolicScope():
        outputs = symbolic_call(fn, args, kwargs, fill_value=83)

-     none_in_shape = any(map(has_none_shape, tree.flatten((args, kwargs))))
+     none_in_shape = any(
+         builtins.map(has_none_shape, tree.flatten((args, kwargs)))
+     )
    if none_in_shape:
        outputs_1 = outputs
        outputs_2 = symbolic_call(fn, args, kwargs, fill_value=89)
@@ -293,6 +311,13 @@ def slice_update(inputs, start_indices, updates):
    return inputs


+ def switch(index, branches, *operands):
+     index = convert_to_tensor(index, "int32")
+     index = mx.clip(index, 0, len(branches) - 1).tolist()
+     operands = tuple(convert_to_tensor(o) for o in operands)
+     return branches[index](*operands)
+
+
def while_loop(
    cond,
    body,
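A usage sketch for `switch` with hypothetical branch functions, showing the index clipping:

    double = lambda x: x * 2
    square = lambda x: x * x

    switch(1, [double, square], mx.array(3))  # -> array(9)
    switch(9, [double, square], mx.array(3))  # clipped to last branch -> array(9)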
@@ -336,6 +361,8 @@ def fori_loop(lower, upper, body_fun, init_val):


def stop_gradient(variable):
+     if isinstance(variable, Variable):
+         variable = variable.value
    return mx.stop_gradient(variable)

@@ -344,55 +371,223 @@ def unstack(x, num=None, axis=0):
    return [yi.squeeze(axis) for yi in y]


+ def random_seed_dtype():
+     # mlx random seed uses uint32.
+     return "uint32"
+
+
def reverse_sequence(xs):
    indices = mx.arange(xs.shape[0] - 1, -1, -1)
    return mx.take(xs, indices, axis=0)

- def scan(f, init, xs, reverse=False, mask=None):
-     states = init
-     outputs_list = []
+ def flip(x, axis=None):
+     if axis is None:
+         # flip all axes
+         axes = range(x.ndim)
+     else:
+         axes = [axis] if isinstance(axis, int) else axis
+
+     for axis in axes:
+         indices = mx.arange(x.shape[axis] - 1, -1, -1)
+         x = mx.take(x, indices, axis=axis)
+
+     return x
+
+
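A quick sketch of `flip` semantics:

    x = mx.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
    flip(x, axis=1)  # -> [[2, 1, 0], [5, 4, 3]]
    flip(x)          # all axes -> [[5, 4, 3], [2, 1, 0]]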
+ def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
+     # Ref: jax.lax.scan
+     if not callable(f):
+         raise TypeError(f"`f` should be a callable. Received: f={f}")
+     if not isinstance(unroll, bool):
+         if not isinstance(unroll, int) or unroll < 1:
+             raise ValueError(
+                 "`unroll` must be a positive integer or boolean. "
+                 f"Received: unroll={unroll}"
+             )
+     if xs is None and length is None:
+         raise ValueError("Got no `xs` to scan over and `length` not provided.")
+
+     input_is_sequence = tree.is_nested(xs)
+     output_is_sequence = tree.is_nested(init)
+
+     def pack_input(x):
+         return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]
+
+     def pack_output(x):
+         return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

-     if mask is not None:
-         x, mask = xs
-         if reverse:
-             x = reverse_sequence(x)
-             mask = reverse_sequence(mask)
-         iterator = zip(x, mask)
+     if xs is None:
+         xs_flat = []
+         n = int(length)
    else:
-         if reverse:
-             if isinstance(xs, tuple):
-                 xs = tuple(reverse_sequence(x) for x in xs)
-             else:
-                 xs = reverse_sequence(xs)
-         iterator = zip(*xs) if isinstance(xs, tuple) else xs
-
-     for x in iterator:
-         result = f(states, x)
-         if isinstance(result, tuple):
-             states, outputs = result
-             if outputs is not None:
-                 outputs_list.append(outputs)
-         else:
-             states = result
-
-     if outputs_list:
-         if isinstance(outputs_list[0], tuple):
-             # Multiple outputs case
-             outputs = tuple(
-                 mx.stack([out[i] for out in outputs_list])
-                 for i in range(len(outputs_list[0]))
+         xs_flat = tree.flatten(xs)
+         xs_flat = [convert_to_tensor(elem) for elem in xs_flat]
+         n = int(length) if length is not None else shape(xs_flat[0])[0]
+
+     init_flat = tree.flatten(init)
+     init_flat = [convert_to_tensor(init) for init in init_flat]
+     init = pack_output(init_flat)
+     dummy_y = [mx.zeros_like(init) for init in init_flat]
+
+     carry = init
+     ys = []
+     maybe_reversed = reversed if reverse else lambda x: x
+     for i in maybe_reversed(range(n)):
+         xs_slice = [x[i] for x in xs_flat]
+         packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None
+         carry, y = f(carry, packed_xs)
+         ys.append(y if y is not None else dummy_y)
+     stacked_y = tree.map_structure(
+         lambda *ys: mx.stack(ys), *maybe_reversed(ys)
+     )
+     return carry, stacked_y
+
+
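A cumulative-sum sketch of the `jax.lax.scan`-style contract, where `f` maps `(carry, x)` to `(new_carry, y)` and the per-step `y`s are stacked:

    def step(carry, x):
        total = carry + x
        return total, total  # new carry, per-step output

    final, ys = scan(step, mx.array(0), mx.array([1, 2, 3, 4]))
    # final -> array(10); ys -> array([1, 3, 6, 10])
    # reverse=True scans from the end but stacks ys in the original order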
+ def map(f, xs):
+     def g(_, x):
+         return (), f(x)
+
+     _, ys = scan(g, (), xs)
+     return ys
+
+
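Since this module-level `map` shadows the builtin (the reason for the `builtins.map` calls elsewhere in this file), a usage sketch:

    ys = map(lambda x: x * x, mx.array([1.0, 2.0, 3.0]))
    # ys -> array([1.0, 4.0, 9.0])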
+ def dilate(x, axis, dilation_rate):
+     x_shape = list(x.shape)
+     # dilated length is (n - 1) * rate + 1 (zeros between elements only)
+     x_shape[axis] = (x.shape[axis] - 1) * dilation_rate + 1
+
+     result = mx.zeros(x_shape, dtype=x.dtype)
+
+     if axis >= 0:
+         slices = [builtins.slice(None)] * axis + [
+             builtins.slice(0, None, dilation_rate)
+         ]
+     else:
+         slices = [Ellipsis, builtins.slice(0, None, dilation_rate)] + [
+             builtins.slice(None)
+         ] * (-1 - axis)
+     result[tuple(slices)] = x
+
+     return result
+
+
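A sketch of the zero insertion for `dilation_rate=2`, the only rate `_interleave` below relies on:

    dilate(mx.array([1, 2, 3]), axis=0, dilation_rate=2)
    # -> array([1, 0, 2, 0, 3])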
+ def associative_scan(f, elems, reverse=False, axis=0):
+     # Ref: jax.lax.associative_scan
+     if not callable(f):
+         raise TypeError(f"`f` should be a callable. Received: f={f}")
+     elems_flat = tree.flatten(elems)
+     elems_flat = [convert_to_tensor(elem) for elem in elems_flat]
+     if reverse:
+         elems_flat = [flip(elem, (axis,)) for elem in elems_flat]
+
+     def _combine(a_flat, b_flat):
+         a = tree.pack_sequence_as(elems, a_flat)
+         b = tree.pack_sequence_as(elems, b_flat)
+         c = f(a, b)
+         c_flat = tree.flatten(c)
+         return c_flat
+
+     num_elems = int(elems_flat[0].shape[axis])
+     if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
+         raise ValueError(
+             "Array inputs to associative_scan must have the same "
+             "first dimension. (saw: {})".format(
+                 [elem.shape for elem in elems_flat]
+             )
+         )
+
+     def _interleave(a, b, axis):
+         """Given two Tensors of static shape, interleave them along axis."""
+         assert (
+             a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
+         )
+
+         # we want to get a: [a1, a2], b: [b1, b2]
+         # to a: [a1, 0, a2, 0], b: [0, b1, 0, b2]
+         a_dil = dilate(a, axis, 2)
+         b_dil = dilate(b, axis, 2)
+
+         a_pad = [[0, 0] for _ in range(a.ndim)]
+         a_pad[axis][-1] = 1 if a.shape[axis] == b.shape[axis] else 0
+
+         b_pad = [[0, 0] for _ in range(b.ndim)]
+         b_pad[axis] = [1, 0] if a.shape[axis] == b.shape[axis] else [1, 1]
+
+         op = mx.bitwise_or if a.dtype == mx.bool_ else mx.add
+         return op(
+             mx.pad(a_dil, a_pad),
+             mx.pad(b_dil, b_pad),
+         )
+
+     def _scan(elems):
+         num_elems = elems[0].shape[axis]
+         if num_elems < 2:
+             return elems
+
+         reduced_elems = _combine(
+             [
+                 slice_along_axis(elem, 0, -1, step=2, axis=axis)
+                 for elem in elems
+             ],
+             [
+                 slice_along_axis(elem, 1, None, step=2, axis=axis)
+                 for elem in elems
+             ],
+         )
+
+         odd_elems = _scan(reduced_elems)
+         if num_elems % 2 == 0:
+             even_elems = _combine(
+                 [slice_along_axis(e, 0, -1, axis=axis) for e in odd_elems],
+                 [
+                     slice_along_axis(e, 2, None, step=2, axis=axis)
+                     for e in elems
+                 ],
            )
        else:
-             # Single output case
-             outputs = mx.stack(outputs_list)
+             even_elems = _combine(
+                 odd_elems,
+                 [
+                     slice_along_axis(e, 2, None, step=2, axis=axis)
+                     for e in elems
+                 ],
+             )

-         if reverse:
-             if isinstance(outputs, tuple):
-                 outputs = tuple(reverse_sequence(out) for out in outputs)
-             else:
-                 outputs = reverse_sequence(outputs)
+         even_elems = [
+             mx.concatenate(
+                 [slice_along_axis(elem, 0, 1, axis=axis), result],
+                 axis=axis,
+             )
+             for (elem, result) in zip(elems, even_elems)
+         ]
+         return list(
+             builtins.map(
+                 functools.partial(_interleave, axis=axis), even_elems, odd_elems
+             )
+         )

-     return states, outputs
+     scans = _scan(elems_flat)
+     if reverse:
+         scans = [flip(scanned, (axis,)) for scanned in scans]
+
+     return tree.pack_sequence_as(elems, scans)
+
+
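An inclusive prefix-sum sketch of the `jax.lax.associative_scan`-style contract, using an associative binary op:

    xs = mx.array([1, 2, 3, 4])
    associative_scan(mx.add, xs)                # -> array([1, 3, 6, 10])
    associative_scan(mx.add, xs, reverse=True)  # -> array([10, 9, 7, 4])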
+ class custom_gradient:
+     """Decorator for custom gradients.
+
+     Args:
+         fun: Forward pass function.
+     """
+
+     def __init__(self, fun):
+         warnings.warn(
+             "`custom_gradient` for the mlx backend acts as a pass-through to "
+             "support the forward pass. No gradient computation or modification "
+             "takes place."
+         )
+         self.fun = fun

-     return states, None
+     def __call__(self, *args, **kwargs):
+         outputs, _ = self.fun(*args, **kwargs)
+         return outputs
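A sketch of the expected `fun` signature: it returns `(outputs, grad_fn)`, and under this backend only `outputs` is kept:

    @custom_gradient
    def log1pexp(x):
        def grad(upstream):
            # custom backward rule; ignored by this pass-through
            return upstream * (1 - 1 / (1 + mx.exp(x)))

        return mx.log(1 + mx.exp(x)), grad

    y = log1pexp(mx.array([0.0, 1.0]))  # forward value only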