
Commit ce9aec7

mlx - nn and core updates post merge (#20832)

* core, nn, and rnn passing tests
* formatting

1 parent: 168f13d

File tree

8 files changed: +455 -66 lines changed

keras/src/backend/mlx/__init__.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -1,13 +1,16 @@
 """MLX backend APIs."""

 from keras.src.backend.common.name_scope import name_scope
+from keras.src.backend.jax.core import random_seed_dtype
 from keras.src.backend.mlx import core
 from keras.src.backend.mlx import image
 from keras.src.backend.mlx import linalg
 from keras.src.backend.mlx import math
 from keras.src.backend.mlx import nn
 from keras.src.backend.mlx import numpy
 from keras.src.backend.mlx import random
+from keras.src.backend.mlx.core import IS_THREAD_SAFE
+from keras.src.backend.mlx.core import SUPPORTS_RAGGED_TENSORS
 from keras.src.backend.mlx.core import SUPPORTS_SPARSE_TENSORS
 from keras.src.backend.mlx.core import Variable
 from keras.src.backend.mlx.core import cast
```

keras/src/backend/mlx/core.py

Lines changed: 239 additions & 44 deletions
```diff
@@ -1,18 +1,26 @@
+import builtins
+import functools
+import warnings
+
 import mlx.core as mx
 import numpy as np

 from keras.src import tree
 from keras.src.backend.common import KerasVariable
 from keras.src.backend.common import standardize_dtype
+from keras.src.backend.common.backend_utils import slice_along_axis
 from keras.src.backend.common.keras_tensor import KerasTensor
 from keras.src.backend.common.stateless_scope import StatelessScope
+from keras.src.backend.common.symbolic_scope import SymbolicScope

 try:
     import h5py
 except ImportError:
     h5py = None

 SUPPORTS_SPARSE_TENSORS = False
+SUPPORTS_RAGGED_TENSORS = False
+IS_THREAD_SAFE = True

 MLX_DTYPES = {
     "float16": mx.float16,
```
```diff
@@ -67,9 +75,11 @@ def _is_h5py_dataset(obj):
     )


-def convert_to_tensor(x, dtype=None, sparse=None):
+def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if sparse:
         raise ValueError("`sparse=True` is not supported with mlx backend")
+    if ragged:
+        raise ValueError("`ragged=True` is not supported with mlx backend")
     mlx_dtype = to_mlx_dtype(dtype) if dtype is not None else None

     if is_tensor(x):
@@ -83,7 +93,13 @@ def convert_to_tensor(x, dtype=None, sparse=None):
         return x.value

     if isinstance(x, np.ndarray):
-        x = x.astype(standardize_dtype(x.dtype))
+        if x.dtype == np.float64:
+            # mlx backend does not support float64
+            x = x.astype(np.float32)
+        if standardize_dtype(x.dtype) == "bfloat16" and mlx_dtype is None:
+            # if a bfloat16 np.ndarray is passed to mx.array with dtype=None
+            # it casts the output to complex64, so we force cast to bfloat16
+            mlx_dtype = mx.bfloat16
         return mx.array(x, dtype=mlx_dtype)

     if isinstance(x, list):
```
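To make the two new branches concrete, here is a small illustration of the resulting dtype behavior (an editor's sketch, not part of the commit; it assumes the mlx backend is installed):

```python
# Editor's sketch (not in the commit): dtype handling illustrated.
import numpy as np
from keras.src.backend.mlx import core

# float64 has no mlx equivalent, so it is downcast to float32:
t = core.convert_to_tensor(np.ones(3, dtype=np.float64))
print(t.dtype)  # mlx.core.float32

# bfloat16 ndarrays (e.g. built with ml_dtypes) would otherwise be
# mis-inferred by mx.array, so the backend pins mlx_dtype=mx.bfloat16.
```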
```diff
@@ -191,10 +207,12 @@ def symbolic_call(fn, args, kwargs, fill_value):
         )
         return fn(*arr_args, **arr_kwargs)

-    with StatelessScope():
+    with StatelessScope(), SymbolicScope():
         outputs = symbolic_call(fn, args, kwargs, fill_value=83)

-    none_in_shape = any(map(has_none_shape, tree.flatten((args, kwargs))))
+    none_in_shape = any(
+        builtins.map(has_none_shape, tree.flatten((args, kwargs)))
+    )
     if none_in_shape:
         outputs_1 = outputs
         outputs_2 = symbolic_call(fn, args, kwargs, fill_value=89)
```
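The switch to `builtins.map` here is not cosmetic: this commit adds a module-level `map` (the `scan`-based helper shown further down), which would otherwise shadow the Python builtin. The same reasoning explains the `builtins.slice` calls in the new `dilate` helper, since the module already defines Keras's own `slice` operation.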
```diff
@@ -293,6 +311,13 @@ def slice_update(inputs, start_indices, updates):
     return inputs


+def switch(index, branches, *operands):
+    index = convert_to_tensor(index, "int32")
+    index = mx.clip(index, 0, len(branches) - 1).tolist()
+    operands = tuple(convert_to_tensor(o) for o in operands)
+    return branches[index](*operands)
+
+
 def while_loop(
     cond,
     body,
```
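`switch` mirrors `jax.lax.switch`: the index is clipped into `[0, len(branches) - 1]`, so out-of-range indices fall back to the first or last branch rather than raising. A usage sketch (editor's illustration, not from the commit):

```python
import mlx.core as mx
from keras.src.backend.mlx.core import switch

branches = [
    lambda x: x + 1.0,  # branch 0
    lambda x: x * 2.0,  # branch 1
]
x = mx.array([1.0, 2.0])

switch(1, branches, x)   # branch 1 -> [2.0, 4.0]
switch(-5, branches, x)  # clipped to branch 0 -> [2.0, 3.0]
switch(9, branches, x)   # clipped to branch 1 -> [2.0, 4.0]
```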
```diff
@@ -336,6 +361,8 @@ def fori_loop(lower, upper, body_fun, init_val):


 def stop_gradient(variable):
+    if isinstance(variable, Variable):
+        variable = variable.value
     return mx.stop_gradient(variable)


```
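The added guard lets `stop_gradient` accept a Keras `Variable` directly, unwrapping it to its underlying `mx.array` before calling `mx.stop_gradient`, which only understands arrays. A hypothetical call pattern (editor's sketch):

```python
# Hypothetical illustration; assumes KERAS_BACKEND=mlx.
from keras.src.backend.mlx.core import Variable, stop_gradient

v = Variable(initializer="zeros", shape=(3,))
frozen = stop_gradient(v)  # previously callers had to pass v.value
```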
```diff
@@ -344,41 +371,76 @@ def unstack(x, num=None, axis=0):
     return [yi.squeeze(axis) for yi in y]


+def random_seed_dtype():
+    # mlx random seed uses uint32.
+    return "uint32"
+
+
 def reverse_sequence(xs):
     indices = mx.arange(xs.shape[0] - 1, -1, -1)
     return mx.take(xs, indices, axis=0)


-def scan(f, init, xs, reverse=False, mask=None):
-    states = init
-    outputs_list = []
+def flip(x, axis=None):
+    if axis is None:
+        # flip all axes
+        axes = range(x.ndim)
+    else:
+        axes = [axis] if isinstance(axis, int) else axis
+
+    for axis in axes:
+        indices = mx.arange(x.shape[axis] - 1, -1, -1)
+        x = mx.take(x, indices, axis=axis)
+
+    return x
+
+
+def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
+    # Ref: jax.lax.scan
+    if not callable(f):
+        raise TypeError(f"`f` should be a callable. Received: f={f}")
+    if not isinstance(unroll, bool):
+        if not isinstance(unroll, int) or unroll < 1:
+            raise ValueError(
+                "`unroll` must be an positive integer or boolean. "
+                f"Received: unroll={unroll}"
+            )
+    if xs is None and length is None:
+        raise ValueError("Got no `xs` to scan over and `length` not provided.")
+
+    input_is_sequence = tree.is_nested(xs)
+    output_is_sequence = tree.is_nested(init)
+
+    def pack_input(x):
+        return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]
+
+    def pack_output(x):
+        return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

-    if mask is not None:
-        x, mask = xs
-        if reverse:
-            x = reverse_sequence(x)
-            mask = reverse_sequence(mask)
-        iterator = zip(x, mask)
+    if xs is None:
+        xs_flat = []
+        n = int(length)
     else:
-        if reverse:
-            if isinstance(xs, tuple):
-                xs = tuple(reverse_sequence(x) for x in xs)
-            else:
-                xs = reverse_sequence(xs)
-        iterator = zip(*xs) if isinstance(xs, tuple) else xs
-
-    for x in iterator:
-        result = f(states, x)
-        if isinstance(result, tuple):
-            states, outputs = result
-            if outputs is not None:
-                outputs_list.append(outputs)
-        else:
-            states = result
-
-    if outputs_list:
-        if isinstance(outputs_list[0], tuple):
-            # Multiple outputs case
-            outputs = tuple(
-                mx.stack([out[i] for out in outputs_list])
-                for i in range(len(outputs_list[0]))
+        xs_flat = tree.flatten(xs)
+        xs_flat = [convert_to_tensor(elem) for elem in xs_flat]
+        n = int(length) if length is not None else shape(xs_flat[0])[0]
+
+    init_flat = tree.flatten(init)
+    init_flat = [convert_to_tensor(init) for init in init_flat]
+    init = pack_output(init_flat)
+    dummy_y = [mx.zeros_like(init) for init in init_flat]
+
+    carry = init
+    ys = []
+    maybe_reversed = reversed if reverse else lambda x: x
+    for i in maybe_reversed(range(n)):
+        xs_slice = [x[i] for x in xs_flat]
+        packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None
+        carry, y = f(carry, packed_xs)
+        ys.append(y if y is not None else dummy_y)
+    stacked_y = tree.map_structure(
+        lambda *ys: mx.stack(ys), *maybe_reversed(ys)
+    )
+    return carry, stacked_y
```
def map(f, xs):
448+
def g(_, x):
449+
return (), f(x)
450+
451+
_, ys = scan(g, (), xs)
452+
return ys
453+
454+
455+
def dilate(x, axis, dilation_rate):
456+
x_shape = list(x.shape)
457+
x_shape[axis] = x.shape[axis] * dilation_rate - 1
458+
459+
result = mx.zeros(x_shape, dtype=x.dtype)
460+
461+
if axis >= 0:
462+
slices = [builtins.slice(None)] * axis + [
463+
builtins.slice(0, None, dilation_rate)
464+
]
465+
else:
466+
slices = [Ellipsis, builtins.slice(0, None, dilation_rate)] + [
467+
builtins.slice(None)
468+
] * (-1 - axis)
469+
result[tuple(slices)] = x
470+
471+
return result
472+
473+
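`map` is expressed directly as a `scan` with an empty carry, and `dilate` is a helper for the interleaving step of `associative_scan` below: it spreads a tensor out along `axis`, leaving `dilation_rate - 1` zeros between neighbors. A quick sketch (editor's example, not from the commit):

```python
import mlx.core as mx
from keras.src.backend.mlx.core import dilate, map as mlx_map

print(dilate(mx.array([1, 2, 3]), axis=0, dilation_rate=2))
# -> array([1, 0, 2, 0, 3])

print(mlx_map(lambda x: x * x, mx.array([1.0, 2.0, 3.0])))
# -> array([1.0, 4.0, 9.0])
```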
```diff
@@ -385,12 +474,102 @@
+def associative_scan(f, elems, reverse=False, axis=0):
+    # Ref: jax.lax.associative_scan
+    if not callable(f):
+        raise TypeError(f"`f` should be a callable. Received: f={f}")
+    elems_flat = tree.flatten(elems)
+    elems_flat = [convert_to_tensor(elem) for elem in elems_flat]
+    if reverse:
+        elems_flat = [flip(elem, (axis,)) for elem in elems_flat]
+
+    def _combine(a_flat, b_flat):
+        a = tree.pack_sequence_as(elems, a_flat)
+        b = tree.pack_sequence_as(elems, b_flat)
+        c = f(a, b)
+        c_flat = tree.flatten(c)
+        return c_flat
+
+    num_elems = int(elems_flat[0].shape[axis])
+    if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
+        raise ValueError(
+            "Array inputs to associative_scan must have the same "
+            "first dimension. (saw: {})".format(
+                [elem.shape for elem in elems_flat]
+            )
+        )
+
+    def _interleave(a, b, axis):
+        """Given two Tensors of static shape, interleave them along axis."""
+        assert (
+            a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
+        )
+
+        # we want to get a: [a1, a2], b: [b1, b2]
+        # to a: [a1, 0, a2, 0], b: [0, b1, 0, b2]
+        a_dil = dilate(a, axis, 2)
+        b_dil = dilate(b, axis, 2)
+
+        a_pad = [[0, 0] for _ in range(a.ndim)]
+        a_pad[axis][-1] = 1 if a.shape[axis] == b.shape[axis] else 0
+
+        b_pad = [[0, 0] for _ in range(b.ndim)]
+        b_pad[axis] = [1, 0] if a.shape[axis] == b.shape[axis] else [1, 1]
+
+        op = mx.bitwise_or if a.dtype == mx.bool_ else mx.add
+        return op(
+            mx.pad(a_dil, a_pad),
+            mx.pad(b_dil, b_pad),
+        )
+
+    def _scan(elems):
+        num_elems = elems[0].shape[axis]
+        if num_elems < 2:
+            return elems
+
+        reduced_elems = _combine(
+            [
+                slice_along_axis(elem, 0, -1, step=2, axis=axis)
+                for elem in elems
+            ],
+            [
+                slice_along_axis(elem, 1, None, step=2, axis=axis)
+                for elem in elems
+            ],
+        )
+
+        odd_elems = _scan(reduced_elems)
+        if num_elems % 2 == 0:
+            even_elems = _combine(
+                [slice_along_axis(e, 0, -1, axis=axis) for e in odd_elems],
+                [
+                    slice_along_axis(e, 2, None, step=2, axis=axis)
+                    for e in elems
+                ],
             )
         else:
-            # Single output case
-            outputs = mx.stack(outputs_list)
+            even_elems = _combine(
+                odd_elems,
+                [
+                    slice_along_axis(e, 2, None, step=2, axis=axis)
+                    for e in elems
+                ],
+            )

-        if reverse:
-            if isinstance(outputs, tuple):
-                outputs = tuple(reverse_sequence(out) for out in outputs)
-            else:
-                outputs = reverse_sequence(outputs)
+        even_elems = [
+            mx.concatenate(
+                [slice_along_axis(elem, 0, 1, axis=axis), result],
+                axis=axis,
+            )
+            for (elem, result) in zip(elems, even_elems)
+        ]
+        return list(
+            builtins.map(
+                functools.partial(_interleave, axis=axis), even_elems, odd_elems
+            )
+        )

-        return states, outputs
+    scans = _scan(elems_flat)
+    if reverse:
+        scans = [flip(scanned, (axis,)) for scanned in scans]
+
+    return tree.pack_sequence_as(elems, scans)
+
+
```
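As in `jax.lax.associative_scan`, this computes all prefix combinations of an associative operator via a recursive even/odd split plus interleave, giving O(log n) combination depth instead of the O(n) chain a sequential scan would need. A prefix-sum sketch (editor's example, not from the commit):

```python
import mlx.core as mx
from keras.src.backend.mlx.core import associative_scan

xs = mx.array([1, 2, 3, 4])
print(associative_scan(mx.add, xs))                # [1, 3, 6, 10]
print(associative_scan(mx.add, xs, reverse=True))  # [10, 9, 7, 4] (suffix sums)
```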
```diff
@@ -397,2 +576,18 @@
+class custom_gradient:
+    """Decorator for custom gradients.
+
+    Args:
+        fun: Forward pass function.
+    """
+
+    def __init__(self, fun):
+        warnings.warn(
+            "`custom_gradient` for the mlx backend acts as a pass-through to "
+            "support the forward pass. No gradient computation or modification "
+            "takes place."
+        )
+        self.fun = fun

-    return states, None
+    def __call__(self, *args, **kwargs):
+        outputs, _ = self.fun(*args, **kwargs)
+        return outputs
```
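Note the caveat in the warning: on mlx, `custom_gradient` only runs the forward function and discards the returned gradient function. An illustrative use (editor's sketch, not from the commit):

```python
from keras.src.backend.mlx.core import custom_gradient

@custom_gradient  # emits the pass-through warning at decoration time
def square(x):
    def grad(upstream):
        return 2.0 * x * upstream  # ignored on mlx

    return x * x, grad

y = square(3.0)  # 9.0; only the forward value is used
```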
