Commit cc2f446

Adds initial support for an MLX backend (#18962)
* Start the mlx backend
* Add the numpy part of the backend and a native accessor
* Start nn and fix core and trainer
* Add mlx to the trainer test
* Small fix in mlx/core
* Fixes to pass the tests
* Add floor in mlx backend
* Fix the styles
* Add several missing ops in MLX's nn
* Fix formatting
* Change mlx to return tuple shapes and add a few ops
* Add definitions for more ops in the mlx backend
* Add affine transformations
* Add uint64 and int64
* Export rnn things from the backend init
* Fix some errors
* Fix diagonal indices and change exceptions to NYI
* Enable the mlx backend in the tests
* Update some things for mlx v0.2.0
* Formatter updates
* Fix for python 3.9
* Add the Github workflow config
* Add mlx in the import test
* Fix import in mlx.math
1 parent e58af3e commit cc2f446

18 files changed (+3091, -2 lines)

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: [3.9]
-        backend: [tensorflow, jax, torch, numpy]
+        backend: [tensorflow, jax, torch, numpy, mlx]
     name: Run tests
     runs-on: ubuntu-latest
     env:
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
{
    "floatx": "float32",
    "epsilon": 1e-07,
    "backend": "mlx",
    "image_data_format": "channels_last"
}
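As context for the new config above (not part of the commit): a minimal sketch of selecting this backend at runtime, assuming mlx is installed and using the standard KERAS_BACKEND environment variable, which must be set before keras is imported.

# Illustrative sketch only: pick the MLX backend the same way the CI config does.
import os

os.environ["KERAS_BACKEND"] = "mlx"  # mirrors the "backend": "mlx" setting above

import keras

print(keras.backend.backend())  # expected to print "mlx"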

integration_tests/import_test.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
     "tensorflow": "tensorflow",
     "torch": "torch torchvision",
     "jax": "jax jaxlib",
+    "mlx": "mlx",
 }
keras/backend/common/variables.py

Lines changed: 3 additions & 1 deletion
@@ -387,7 +387,9 @@ def standardize_dtype(dtype):
     if hasattr(dtype, "name"):
         dtype = dtype.name
     elif hasattr(dtype, "__str__") and (
-        "torch" in str(dtype) or "jax.numpy" in str(dtype)
+        "torch" in str(dtype)
+        or "jax.numpy" in str(dtype)
+        or "mlx" in str(dtype)
     ):
         dtype = str(dtype).split(".")[-1]
     elif hasattr(dtype, "__name__"):
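Why the extra branch is needed (illustrative, not part of the commit): MLX dtypes do not appear to expose a `name` attribute, so `standardize_dtype` falls back to their string form. The sketch below assumes `str(mx.float32)` renders as something like "mlx.core.float32".

# Minimal sketch, assuming str(mx.float32) contains "mlx", e.g. "mlx.core.float32".
import mlx.core as mx

from keras.backend.common import standardize_dtype

dtype_str = str(mx.float32)  # e.g. "mlx.core.float32"
assert "mlx" in dtype_str  # takes the new branch above
print(standardize_dtype(mx.float32))  # expected: "float32"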

keras/backend/mlx/__init__.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
"""MLX backend APIs."""

from keras.backend.mlx import core
from keras.backend.mlx import image
from keras.backend.mlx import math
from keras.backend.mlx import nn
from keras.backend.mlx import numpy
from keras.backend.mlx import random
from keras.backend.mlx.core import SUPPORTS_SPARSE_TENSORS
from keras.backend.mlx.core import Variable
from keras.backend.mlx.core import cast
from keras.backend.mlx.core import compute_output_spec
from keras.backend.mlx.core import cond
from keras.backend.mlx.core import convert_to_numpy
from keras.backend.mlx.core import convert_to_tensor
from keras.backend.mlx.core import is_tensor
from keras.backend.mlx.core import scatter
from keras.backend.mlx.core import shape
from keras.backend.mlx.core import stop_gradient
from keras.backend.mlx.core import to_mlx_dtype
from keras.backend.mlx.core import vectorized_map
from keras.backend.mlx.rnn import cudnn_ok
from keras.backend.mlx.rnn import gru
from keras.backend.mlx.rnn import lstm
from keras.backend.mlx.rnn import rnn
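A quick sketch (not part of the commit) of the contract these re-exports serve: with the mlx backend selected, `keras.ops` is expected to produce `mx.array` values that the exported helpers can inspect and convert. Assumes mlx is installed and KERAS_BACKEND=mlx was set before importing keras.

# Illustrative sketch only.
import keras
from keras.backend.mlx import convert_to_numpy
from keras.backend.mlx import is_tensor

x = keras.ops.ones((2, 3))
print(is_tensor(x))  # expected: True, i.e. x is an mx.array
print(convert_to_numpy(x).shape)  # expected: (2, 3), as a NumPy array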

keras/backend/mlx/core.py

Lines changed: 290 additions & 0 deletions
@@ -0,0 +1,290 @@
import mlx.core as mx
import numpy as np
import tree

from keras.backend.common import KerasVariable
from keras.backend.common import standardize_dtype
from keras.backend.common.keras_tensor import KerasTensor
from keras.backend.common.stateless_scope import StatelessScope
from keras.utils.nest import pack_sequence_as

SUPPORTS_SPARSE_TENSORS = False

MLX_DTYPES = {
    "float16": mx.float16,
    "float32": mx.float32,
    "float64": None,  # mlx does not support float64
    "uint8": mx.uint8,
    "uint16": mx.uint16,
    "uint32": mx.uint32,
    "uint64": mx.uint64,
    "int8": mx.int8,
    "int16": mx.int16,
    "int32": mx.int32,
    "int64": mx.int64,
    "bfloat16": mx.bfloat16,
    "bool": mx.bool_,
}


def to_mlx_dtype(dtype):
    if isinstance(dtype, mx.Dtype):
        return dtype
    standardized_dtype = MLX_DTYPES.get(standardize_dtype(dtype), None)
    if standardized_dtype is None:
        raise ValueError(f"Unsupported dtype for MLX: {dtype}")
    return standardized_dtype


class Variable(KerasVariable):
    def _initialize(self, value):
        self._value = convert_to_tensor(value, dtype=self._dtype)

    def _direct_assign(self, value):
        self._value = value

    def _convert_to_tensor(self, value, dtype=None):
        return convert_to_tensor(value, dtype=dtype)

    def __mlx_array__(self):
        return self.value

    def __array__(self, dtype=None):
        value = convert_to_numpy(self._value)
        if dtype:
            return value.astype(dtype)
        return value


def convert_to_tensor(x, dtype=None, sparse=None):
    if sparse:
        raise ValueError("`sparse=True` is not supported with mlx backend")
    mlx_dtype = to_mlx_dtype(dtype) if dtype is not None else None

    if is_tensor(x):
        if dtype is None:
            return x
        return x.astype(mlx_dtype)

    if isinstance(x, Variable):
        if dtype and standardize_dtype(dtype) != x.dtype:
            return x.value.astype(mlx_dtype)
        return x.value

    if isinstance(x, np.ndarray):
        if x.dtype == np.int64:
            x = x.astype(np.int32)
        x = x.astype(standardize_dtype(x.dtype))
        return mx.array(x, dtype=mlx_dtype)

    if isinstance(x, list):

        def to_scalar_list(x):
            if isinstance(x, list):
                return [to_scalar_list(xi) for xi in x]
            elif isinstance(x, mx.array):
                if x.ndim == 0:
                    return x.item()
                else:
                    return x.tolist()
            else:
                return x

        return mx.array(to_scalar_list(x), dtype=mlx_dtype)

    return mx.array(x, dtype=mlx_dtype)


def convert_to_tensors(*xs):
    ys = [None] * len(xs)
    dtype = None
    for i, x in enumerate(xs):
        if not isinstance(x, (int, float, bool)):
            ys[i] = convert_to_tensor(x)
            dtype = ys[i].dtype
    # Floating point wins so scalars promote to dtype
    if dtype in (mx.float32, mx.float16, mx.bfloat16):
        for i, x in enumerate(xs):
            if ys[i] is None:
                ys[i] = mx.array(x, dtype=dtype)
    # Bool loses against everything so scalars keep their type
    elif dtype == mx.bool_:
        for i, x in enumerate(xs):
            if ys[i] is None:
                ys[i] = mx.array(x)
    # Integral types keep their type except if the scalar is a float
    else:
        for i, x in enumerate(xs):
            if ys[i] is None:
                if isinstance(x, float):
                    ys[i] = mx.array(x)
                else:
                    ys[i] = mx.array(x, dtype=dtype)

    return ys


def convert_to_numpy(x):
    # Performs a copy. If we want 0-copy we can pass copy=False
    return np.array(x)


def is_tensor(x):
    return isinstance(x, mx.array)


def shape(x):
    return tuple(x.shape)


def cast(x, dtype):
    return convert_to_tensor(x, dtype=dtype)


# Shape / dtype inference util
def compute_output_spec(fn, *args, **kwargs):
    def has_none_shape(x):
        """Check for if a `KerasTensor` has dynamic shape."""
        if isinstance(x, KerasTensor):
            return None in x.shape
        return False

    def convert_keras_tensor_to_mlx(x, fill_value=None):
        """Convert `KerasTensor`s to `mlx.array`s."""
        if isinstance(x, KerasTensor):
            shape = list(x.shape)
            if fill_value:
                for i, e in enumerate(shape):
                    if e is None:
                        shape[i] = fill_value
            return mx.ones(shape, dtype=MLX_DTYPES[x.dtype])
        return x

    def convert_mlx_to_keras_tensor(x):
        """Convert `mlx.array`s to `KerasTensor`s."""
        if is_tensor(x):
            return KerasTensor(x.shape, standardize_dtype(x.dtype))
        return x

    def symbolic_call(fn, args, kwargs, fill_value):
        """Call `fn` to infer output shape and dtype."""
        arr_args, arr_kwargs = tree.map_structure(
            lambda x: convert_keras_tensor_to_mlx(x, fill_value),
            (args, kwargs),
        )
        return fn(*arr_args, **arr_kwargs)

    with StatelessScope():
        outputs = symbolic_call(fn, args, kwargs, fill_value=83)

        none_in_shape = any(map(has_none_shape, tree.flatten((args, kwargs))))
        if none_in_shape:
            outputs_1 = outputs
            outputs_2 = symbolic_call(fn, args, kwargs, fill_value=89)

            flat_out_1 = tree.flatten(outputs_1)
            flat_out_2 = tree.flatten(outputs_2)

            flat_out = []
            for x1, x2 in zip(flat_out_1, flat_out_2):
                shape = list(x1.shape)
                for i, e in enumerate(x2.shape):
                    if e != shape[i]:
                        shape[i] = None
                flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype)))
            outputs = pack_sequence_as(outputs_1, flat_out)

        output_spec = tree.map_structure(convert_mlx_to_keras_tensor, outputs)
    return output_spec


def cond(pred, true_fn, false_fn):
    # TODO: How should we avoid evaluating pred in case we are tracing?
    if pred:
        return true_fn()
    return false_fn()


def vectorized_map(function, elements):
    return mx.vmap(function)(elements)


def scatter(indices, values, shape):
    indices = convert_to_tensor(indices)
    values = convert_to_tensor(values)
    zeros = mx.zeros(shape, dtype=values.dtype)
    indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    zeros = zeros.at[indices].add(values)

    return zeros


def scatter_update(inputs, indices, updates):
    inputs = convert_to_tensor(inputs)
    indices = convert_to_tensor(indices)
    updates = convert_to_tensor(updates)
    indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    inputs[indices] = updates

    return inputs


def slice(inputs, start_indices, shape):
    inputs = convert_to_tensor(inputs)

    python_slice = __builtins__["slice"]
    slices = tuple(
        python_slice(int(start_index), int(start_index + length))
        for start_index, length in zip(start_indices, shape)
    )
    return inputs[slices]


def slice_update(inputs, start_indices, updates):
    inputs = convert_to_tensor(inputs)
    updates = convert_to_tensor(updates)

    python_slice = __builtins__["slice"]
    slices = tuple(
        python_slice(int(start_index), int(start_index + update_length))
        for start_index, update_length in zip(start_indices, updates.shape)
    )
    inputs[slices] = updates
    return inputs


def while_loop(
    cond,
    body,
    loop_vars,
    maximum_iterations=None,
):
    # TODO: How should we avoid evaluating cond when tracing?
    current_iter = 0
    iteration_check = (
        lambda iter: maximum_iterations is None or iter < maximum_iterations
    )
    loop_vars = tuple([convert_to_tensor(v) for v in loop_vars])
    while cond(*loop_vars) and iteration_check(current_iter):
        loop_vars = body(*loop_vars)
        if not isinstance(loop_vars, (list, tuple)):
            loop_vars = (loop_vars,)
        loop_vars = tuple(loop_vars)
        current_iter += 1
    return loop_vars


def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def stop_gradient(variable):
    return mx.stop_gradient(variable)


def unstack(x, num=None, axis=0):
    y = x.split(num or x.shape[axis], axis=axis)
    return [yi.squeeze(axis) for yi in y]