Skip to content

Commit c77a339

Browse files
authored
mlx - trainer update for compilation and data adapters (#20957)
* mlx trainer update and data adapters * formatting
1 parent aad52f1 commit c77a339

28 files changed

+570
-248
lines changed

keras/src/backend/common/dtypes_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,15 @@ def setUp(self):
4747

4848
self.jax_enable_x64 = enable_x64()
4949
self.jax_enable_x64.__enter__()
50+
if backend.backend() == "mlx":
51+
self.mlx_cpu_context = backend.core.enable_float64()
52+
self.mlx_cpu_context.__enter__()
5053
return super().setUp()
5154

5255
def tearDown(self):
5356
self.jax_enable_x64.__exit__(None, None, None)
57+
if backend.backend() == "mlx":
58+
self.mlx_cpu_context.__exit__(None, None, None)
5459
return super().tearDown()
5560

5661
@parameterized.named_parameters(

keras/src/backend/common/variables_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,13 @@ def test_standardize_dtype(self, dtype):
256256
):
257257
self.skipTest(f"openvino backend does not support dtype {dtype}")
258258

259+
if backend.backend() == "mlx" and dtype in (
260+
"complex128",
261+
"float8_e4m3fn",
262+
"float8_e5m2",
263+
):
264+
self.skipTest(f"mlx backend does not support dtype {dtype}")
265+
259266
x = backend.convert_to_tensor(np.zeros(()), dtype)
260267
actual = standardize_dtype(x.dtype)
261268
self.assertEqual(actual, dtype)
@@ -805,6 +812,8 @@ def test_invalid_float(self):
805812
elif backend.backend() == "openvino":
806813
# TODO: openvino doesn't support complex
807814
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ["complex128", "complex64"]]
815+
elif backend.backend() == "mlx":
816+
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ["complex128"]]
808817
# Remove float8 dtypes for the following tests
809818
ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
810819
NON_COMPLEX_DTYPES = [x for x in ALL_DTYPES if x and x not in COMPLEX_DTYPES]
@@ -818,10 +827,15 @@ def setUp(self):
818827

819828
self.jax_enable_x64 = enable_x64()
820829
self.jax_enable_x64.__enter__()
830+
if backend.backend() == "mlx":
831+
self.mlx_cpu_context = backend.core.enable_float64()
832+
self.mlx_cpu_context.__enter__()
821833
return super().setUp()
822834

823835
def tearDown(self):
824836
self.jax_enable_x64.__exit__(None, None, None)
837+
if backend.backend() == "mlx":
838+
self.mlx_cpu_context.__exit__(None, None, None)
825839
return super().tearDown()
826840

827841
@parameterized.named_parameters(

keras/src/backend/mlx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from keras.src.backend.mlx.core import cond
1818
from keras.src.backend.mlx.core import convert_to_numpy
1919
from keras.src.backend.mlx.core import convert_to_tensor
20+
from keras.src.backend.mlx.core import device_scope
2021
from keras.src.backend.mlx.core import is_tensor
2122
from keras.src.backend.mlx.core import random_seed_dtype
2223
from keras.src.backend.mlx.core import scatter

keras/src/backend/mlx/core.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
MLX_DTYPES = {
2626
"float16": mx.float16,
2727
"float32": mx.float32,
28-
"float64": None, # mlx only supports float64 on cpu
28+
"float64": mx.float64, # for mlx float64 only supported on cpu
2929
"uint8": mx.uint8,
3030
"uint16": mx.uint16,
3131
"uint32": mx.uint32,
@@ -104,19 +104,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
104104
return mx.array(x, dtype=mlx_dtype)
105105

106106
if isinstance(x, list):
107-
108-
def to_scalar_list(x):
109-
if isinstance(x, list):
110-
return [to_scalar_list(xi) for xi in x]
111-
elif isinstance(x, mx.array):
112-
if x.ndim == 0:
113-
return x.item()
114-
else:
115-
return x.tolist()
116-
else:
117-
return x
118-
119-
return mx.array(to_scalar_list(x), dtype=mlx_dtype)
107+
return mx.array(x, dtype=mlx_dtype)
120108

121109
if _is_h5py_dataset(x):
122110
if h5py is None:
@@ -592,3 +580,41 @@ def __init__(self, fun):
592580
def __call__(self, *args, **kwargs):
593581
outputs, _ = self.fun(*args, **kwargs)
594582
return outputs
583+
584+
585+
def enable_float64():
586+
"""Returns context manager forcing operations on cpu
587+
588+
MLX requires operations involving `float64` to be on cpu,
589+
mimicking jax's `enable_x64()`
590+
591+
Usage:
592+
```
593+
a = mx.array([1, 2, 3], dtype=mx.float64)
594+
b = mx.array([4, 5, 6], dtype=mx.float64)
595+
596+
with enable_float64():
597+
c = mx.add(a, b)
598+
599+
# OR
600+
mlx_cpu_context = mx.stream(mx.cpu)
601+
mlx_cpu_context.__enter__()
602+
c = mx.add(a, b)
603+
mlx_cpu_context.__exit__(None, None, None)
604+
```
605+
"""
606+
return mx.stream(mx.cpu)
607+
608+
609+
def device_scope(device_name):
610+
if isinstance(device_name, str):
611+
mlx_device = mx.cpu if "cpu" in device_name.lower() else mx.gpu
612+
elif not isinstance(device_name, mx.Device):
613+
raise ValueError(
614+
"Invalid value for argument `device_name`. "
615+
"Expected a string like 'gpu:0' or a `mlx.core.Device` instance. "
616+
f"Received: device_name='{device_name}'"
617+
)
618+
else:
619+
mlx_device = device_name
620+
return mx.stream(mlx_device)

keras/src/backend/mlx/numpy.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,16 @@ def right_shift(x, y):
291291

292292

293293
def _bincount_1d(x, weights=None, minlength=0):
294-
length = builtins.max(builtins.max(x) + 1, minlength or 0)
294+
x_max = mx.max(x)
295+
length = mx.maximum(x_max + 1, minlength)
296+
295297
counts = mx.zeros(length)
296-
w = weights if weights is not None else mx.ones_like(x)
297-
return counts.at[x].add(w)
298+
if weights is None:
299+
counts = counts.at[x].add(1)
300+
else:
301+
counts = counts.at[x].add(weights)
302+
303+
return counts
298304

299305

300306
def bincount(x, weights=None, minlength=0, sparse=False):

keras/src/backend/mlx/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _sample_gamma(shape, a, b, key):
226226
def binomial(shape, counts, probabilities, dtype=None, seed=None):
227227
# Binomial(n, p) distribution by summing n Bernoulli(p) samples
228228
dtype = to_mlx_dtype(dtype)
229-
key = mx.random.key(seed)
229+
key = mlx_draw_seed(seed)
230230

231231
if isinstance(shape, int):
232232
shape = (shape,)

0 commit comments

Comments
 (0)