Skip to content

Commit dc9ed28

Browse files
authored
mlx - data adapters and saving (#21023)
* updates to data adapters with mlx arrays and other backends and keras saving * formatting
1 parent f337bb4 commit dc9ed28

File tree

6 files changed

+45
-4
lines changed

6 files changed

+45
-4
lines changed

keras/src/backend/mlx/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ def convert_to_tensors(*xs):
153153

154154
def convert_to_numpy(x):
155155
# Performs a copy. If we want 0-copy we can pass copy=False
156+
if isinstance(x, mx.array) and x.dtype == mx.bfloat16:
157+
# mlx currently has an error passing bloat16 array to numpy
158+
return np.array(x.astype(mx.float32))
156159
return np.array(x)
157160

158161

keras/src/saving/saving_api_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from absl import logging
66
from absl.testing import parameterized
77

8+
from keras.src import backend
89
from keras.src import layers
910
from keras.src.models import Sequential
1011
from keras.src.saving import saving_api
@@ -122,6 +123,11 @@ def get_model(self, dtype=None):
122123
)
123124
def test_basic_load(self, dtype):
124125
"""Test basic model loading."""
126+
if backend.backend() == "mlx" and dtype == "float64":
127+
self.skipTest(
128+
"mlx backend does not yet support float64 in random and uniform"
129+
)
130+
125131
model = self.get_model(dtype)
126132
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
127133
saving_api.save_model(model, filepath)
@@ -208,6 +214,13 @@ def get_model(self, dtype=None):
208214
)
209215
def test_load_keras_weights(self, source_dtype, dest_dtype):
210216
"""Test loading keras weights."""
217+
if backend.backend() == "mlx":
218+
if source_dtype == "float64" or dest_dtype == "float64":
219+
self.skipTest(
220+
"mlx backend does not yet support float64 in "
221+
"random and uniform"
222+
)
223+
211224
src_model = self.get_model(dtype=source_dtype)
212225
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
213226
src_model.save_weights(filepath)

keras/src/saving/serialization_lib_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
import keras
9+
from keras.src import backend
910
from keras.src import ops
1011
from keras.src import testing
1112
from keras.src.saving import serialization_lib
@@ -112,7 +113,11 @@ def test_serialize_ellipsis(self):
112113
self.assertEqual(..., deserialized)
113114

114115
def test_tensors_and_shapes(self):
115-
x = ops.random.normal((2, 2), dtype="float64")
116+
if backend.backend() == "mlx":
117+
# mlx backend does not yet support float64 in normal and uniform
118+
x = ops.random.normal((2, 2), dtype="float32")
119+
else:
120+
x = ops.random.normal((2, 2), dtype="float64")
116121
obj = {"x": x}
117122
_, new_obj, _ = self.roundtrip(obj)
118123
self.assertAllClose(x, new_obj["x"], atol=1e-5)

keras/src/trainers/data_adapters/array_slicing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,26 @@ def __getitem__(self, indices):
146146

147147
return self.array[mx.array(indices)]
148148

149+
@classmethod
150+
def cast(cls, x, dtype):
151+
from keras.src.backend.mlx.core import cast
152+
153+
return cast(x, dtype)
154+
155+
@classmethod
156+
def convert_to_numpy(cls, x):
157+
from keras.src.backend.mlx.core import convert_to_numpy
158+
159+
return convert_to_numpy(x)
160+
161+
@classmethod
162+
def convert_to_jax_compatible(cls, x):
163+
return cls.convert_to_numpy(x)
164+
165+
@classmethod
166+
def convert_to_tf_dataset_compatible(cls, x):
167+
return cls.convert_to_numpy(x)
168+
149169

150170
class TensorflowSliceable(Sliceable):
151171
def __getitem__(self, indices):

keras/src/trainers/data_adapters/data_adapter_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ def convert_to_numpy(x):
222222
if is_torch_tensor(x):
223223
x = x.cpu()
224224
x = np.asarray(x)
225+
if is_mlx_array(x):
226+
x = np.array(x)
225227
return x
226228

227229
for batch in iterable:

keras/src/trainers/data_adapters/tf_dataset_adapter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ def get_mlx_iterator(self):
6262
def convert_to_mlx(x):
6363
if isinstance(x, tf.SparseTensor):
6464
x = sparse_to_dense(x)
65-
# tensorflow supports the buffer protocol
66-
# but requires explicit memoryview with mlx
67-
return mlx.core.array(memoryview(x))
65+
return mlx.core.array(x)
6866

6967
for batch in self._dataset:
7068
yield tree.map_structure(convert_to_mlx, batch)

0 commit comments

Comments
 (0)