Skip to content

Commit f337bb4

Browse files
authored
mlx updates post compilation (#20996)
1 parent 89f6103 commit f337bb4

File tree

11 files changed

+218
-46
lines changed

11 files changed

+218
-46
lines changed

keras/src/backend/mlx/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
114114
# load h5py._hl.dataset.Dataset object with numpy
115115
x = np.array(x)
116116

117+
if x is None:
118+
# this is needed for tracking
119+
# mlx.array returns a TypeError when called with None
120+
raise ValueError("mlx cannot convert `None` to array")
121+
117122
return mx.array(x, dtype=mlx_dtype)
118123

119124

keras/src/backend/mlx/export.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,45 @@
1+
from keras.src import layers
2+
from keras.src import tree
3+
4+
15
class MlxExportArchive:
6+
def __init__(self):
7+
self._backend_variables = []
8+
self._backend_trainable_variables = []
9+
self._backend_non_trainable_variables = []
10+
211
def track(self, resource):
3-
raise NotImplementedError(
4-
"`track` is not implemented in the mlx backend."
5-
)
12+
if not isinstance(resource, layers.Layer):
13+
raise ValueError(
14+
"Invalid resource type. Expected an instance of a "
15+
"MLX-based Keras `Layer` or `Model`. "
16+
f"Received instead an object of type '{type(resource)}'. "
17+
f"Object received: {resource}"
18+
)
19+
20+
if isinstance(resource, layers.Layer):
21+
# Variables in the lists below are actually part of the trackables
22+
# that get saved, because the lists are created in __init__.
23+
trainable_variables = resource.trainable_variables
24+
non_trainable_variables = resource.non_trainable_variables
25+
26+
self._tf_trackable.trainable_variables += tree.map_structure(
27+
self._convert_to_tf_variable, trainable_variables
28+
)
29+
self._tf_trackable.non_trainable_variables += tree.map_structure(
30+
self._convert_to_tf_variable, non_trainable_variables
31+
)
32+
self._tf_trackable.variables = (
33+
self._tf_trackable.trainable_variables
34+
+ self._tf_trackable.non_trainable_variables
35+
)
36+
37+
self._backend_trainable_variables += trainable_variables
38+
self._backend_non_trainable_variables += non_trainable_variables
39+
self._backend_variables = (
40+
self._backend_trainable_variables
41+
+ self._backend_non_trainable_variables
42+
)
643

744
def add_endpoint(self, name, fn, input_signature=None, **kwargs):
845
raise NotImplementedError(

keras/src/backend/mlx/math.py

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import math
2-
import operator
32

43
import mlx.core as mx
5-
import numpy as np
64

75
from keras.src.backend import standardize_dtype
6+
from keras.src.backend.common.backend_utils import canonicalize_axis
87
from keras.src.backend.mlx.core import convert_to_tensor
98
from keras.src.backend.mlx.linalg import det
109
from keras.src.utils.module_utils import scipy
@@ -23,26 +22,31 @@ def _segment_reduction_fn(
2322
if num_segments is None:
2423
num_segments = mx.max(segment_ids) + 1
2524

26-
valid_indices = segment_ids >= 0
27-
valid_data = mx.array(
28-
np.array(data)[valid_indices] # MLX does not support boolean indices
29-
)
30-
valid_segment_ids = mx.array(np.array(segment_ids)[valid_indices])
31-
32-
data_shape = list(valid_data.shape)
33-
data_shape[0] = num_segments
25+
mask = segment_ids >= 0
26+
# pack segment_ids < 0 into index 0 and then handle below
27+
safe_segment_ids = mx.where(mask, segment_ids, 0)
3428

3529
if not sorted:
36-
sort_indices = mx.argsort(valid_segment_ids)
37-
valid_segment_ids = valid_segment_ids[sort_indices]
38-
valid_data = valid_data[sort_indices]
30+
sort_indices = mx.argsort(safe_segment_ids)
31+
safe_segment_ids = mx.take(safe_segment_ids, sort_indices)
32+
data = mx.take(data, sort_indices, axis=0)
33+
mask = mx.take(mask, sort_indices)
34+
35+
# expand mask dimensions to match data dimensions
36+
for i in range(1, len(data.shape)):
37+
mask = mx.expand_dims(mask, axis=i)
38+
39+
data_shape = list(data.shape)
40+
data_shape[0] = num_segments
3941

4042
if reduction_method == "max":
41-
result = mx.ones(data_shape, dtype=valid_data.dtype) * -mx.inf
42-
result = result.at[valid_segment_ids].maximum(valid_data)
43+
masked_data = mx.where(mask, data, -mx.inf)
44+
result = mx.ones(data_shape, dtype=data.dtype) * -mx.inf
45+
result = result.at[safe_segment_ids].maximum(masked_data)
4346
else: # sum
44-
result = mx.zeros(data_shape, dtype=valid_data.dtype)
45-
result = result.at[valid_segment_ids].add(valid_data)
47+
masked_data = mx.where(mask, data, 0)
48+
result = mx.zeros(data_shape, dtype=data.dtype)
49+
result = result.at[safe_segment_ids].add(masked_data)
4650

4751
return result
4852

@@ -154,19 +158,6 @@ def irfft(x, fft_length=None):
154158
return real_output
155159

156160

157-
def _canonicalize_axis(axis, num_dims):
158-
# Ref: jax.scipy.signal.stft
159-
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
160-
axis = operator.index(axis)
161-
if not -num_dims <= axis < num_dims:
162-
raise ValueError(
163-
f"axis {axis} is out of bounds for array of dimension {num_dims}"
164-
)
165-
if axis < 0:
166-
axis = axis + num_dims
167-
return axis
168-
169-
170161
def _create_sliding_windows(x, window_size, step):
171162
batch_size, signal_length, _ = x.shape
172163
num_windows = (signal_length - window_size) // step + 1
@@ -187,7 +178,7 @@ def _create_sliding_windows(x, window_size, step):
187178

188179
def _stft(x, window, nperseg, noverlap, nfft, axis=-1):
189180
# Ref: jax.scipy.signal.stft
190-
axis = _canonicalize_axis(axis, x.ndim)
181+
axis = canonicalize_axis(axis, x.ndim)
191182
result_dtype = mx.complex64
192183

193184
if x.size == 0:
@@ -364,8 +355,8 @@ def _istft(
364355
# Ref: jax.scipy.signal.istft
365356
if Zxx.ndim < 2:
366357
raise ValueError("Input stft must be at least 2d!")
367-
freq_axis = _canonicalize_axis(freq_axis, Zxx.ndim)
368-
time_axis = _canonicalize_axis(time_axis, Zxx.ndim)
358+
freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
359+
time_axis = canonicalize_axis(time_axis, Zxx.ndim)
369360

370361
if freq_axis == time_axis:
371362
raise ValueError("Must specify differing time and frequency axes!")

keras/src/backend/mlx/nn.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,9 +1295,6 @@ def dot_product_attention(
12951295
key = convert_to_tensor(key)
12961296
value = convert_to_tensor(value)
12971297

1298-
query = convert_to_tensor(query)
1299-
key = convert_to_tensor(key)
1300-
value = convert_to_tensor(value)
13011298
if len(query.shape) != 4:
13021299
raise ValueError(
13031300
"`dot_product_attention` only supports 4D inputs. "

keras/src/layers/core/einsum_dense_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ def test_lora_rank_argument(self):
382382

383383
# Test quantization-related (int8 and float8) methods
384384

385+
@pytest.mark.skipif(
386+
backend.backend() == "mlx",
387+
reason="mlx backend doesn't int8 matmul.",
388+
)
385389
def test_quantize_int8(self):
386390
layer = layers.EinsumDense(
387391
equation="ab,bcd->acd",
@@ -470,6 +474,10 @@ def test_quantize_int8(self):
470474
("btd,ndh->btnh", "btd,ndh->btnh", (None, 2, 8), (1, 2, 4)),
471475
("btd,df->btf", "btd,df->btf", (None, 4), (1, 2, 4)),
472476
)
477+
@pytest.mark.skipif(
478+
backend.backend() == "mlx",
479+
reason="mlx backend doesn't int8 matmul.",
480+
)
473481
def test_quantize_int8_with_specific_equations(
474482
self, equation, output_shape, input_shape
475483
):
@@ -608,6 +616,11 @@ def test_quantize_invalid_mode(self, mode):
608616
def test_quantize_dtype_argument(
609617
self, dtype, num_trainable_weights, num_non_trainable_weights
610618
):
619+
if backend.backend() == "mlx":
620+
if "int8" in dtype:
621+
self.skipTest("mlx backend doesn't support int8 matmul")
622+
if "float8" in dtype:
623+
self.skipTest("mlx backend doesn't support float8")
611624
self.run_layer_test(
612625
layers.EinsumDense,
613626
init_kwargs={
@@ -630,6 +643,10 @@ def test_quantize_dtype_argument(
630643
("btd,ndh->btnh", "btd,ndh->btnh", (1, 4, 32), (1, 4, 8, 16)),
631644
)
632645
@pytest.mark.requires_trainable_backend
646+
@pytest.mark.skipif(
647+
backend.backend() == "mlx",
648+
reason="mlx backend doesn't int8 matmul.",
649+
)
633650
def test_quantize_int8_when_lora_enabled(
634651
self, equation, input_shape, output_shape
635652
):

keras/src/layers/normalization/spectral_normalization.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from keras.src import backend
12
from keras.src import initializers
23
from keras.src import ops
34
from keras.src.api_export import keras_export
@@ -76,17 +77,33 @@ def build(self, input_shape):
7677

7778
def call(self, inputs, training=False):
7879
if training:
79-
new_vector_u, new_kernel = ops.cond(
80-
ops.all(ops.equal(self.kernel.value, 0)),
81-
lambda: (self.vector_u.value, self.kernel.value),
82-
self.normalized_weights,
83-
)
80+
if backend.backend() == "mlx":
81+
# ops.cond is non-compilable with mlx backend
82+
new_vector_u, new_kernel = self._mlx_get_kernel_update()
83+
else:
84+
new_vector_u, new_kernel = ops.cond(
85+
ops.all(ops.equal(self.kernel.value, 0)),
86+
lambda: (self.vector_u.value, self.kernel.value),
87+
self.normalized_weights,
88+
)
8489
self.vector_u.assign(new_vector_u)
8590
self.kernel.assign(new_kernel)
8691

8792
output = self.layer(inputs)
8893
return ops.cast(output, inputs.dtype)
8994

95+
def _mlx_get_kernel_update(self):
96+
kernel_all_zero = ops.all(ops.equal(self.kernel.value, 0))
97+
kernel_all_zero = ops.stop_gradient(kernel_all_zero)
98+
normalized_vector_u, normalized_kernel = self.normalized_weights()
99+
new_vector_u = ops.where(
100+
kernel_all_zero, self.vector_u.value, normalized_vector_u
101+
)
102+
new_kernel = ops.where(
103+
kernel_all_zero, self.kernel.value, normalized_kernel
104+
)
105+
return new_vector_u, new_kernel
106+
90107
def compute_output_shape(self, input_shape):
91108
return self.layer.compute_output_shape(input_shape)
92109

keras/src/layers/preprocessing/normalization.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def __init__(
122122
f"must be set. Received: mean={mean} and variance={variance}"
123123
)
124124

125+
self._mlx_inputs_captured = False
126+
125127
def build(self, input_shape):
126128
if input_shape is None:
127129
return
@@ -297,6 +299,19 @@ def finalize_state(self):
297299
self.variance = ops.reshape(self.adapt_variance, self._broadcast_shape)
298300
self.variance = ops.cast(self.variance, self.compute_dtype)
299301

302+
def _mlx_capture_inputs(self):
303+
# due to mlx's lazy evaluation
304+
# when compiled, the mean and variance need to be evaluated
305+
# or the values will not be captured and an error thrown
306+
if self._mlx_inputs_captured:
307+
return
308+
309+
from keras.src.utils.module_utils import mlx
310+
311+
mlx.core.eval(self.mean)
312+
mlx.core.eval(self.variance)
313+
self._mlx_inputs_captured = True
314+
300315
def call(self, inputs):
301316
# This layer can be called in tf.data
302317
# even with another backend after it has been adapted.
@@ -314,6 +329,8 @@ def call(self, inputs):
314329
# possible to cause breakage when using this layer in tf.data.
315330
mean = self.convert_weight(self.mean)
316331
variance = self.convert_weight(self.variance)
332+
if self.backend.name == "mlx":
333+
self._mlx_capture_inputs()
317334
if self.invert:
318335
return self.backend.numpy.add(
319336
mean,

keras/src/layers/reshaping/up_sampling2d_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def test_upsampling_2d_correctness(self):
128128
def test_upsampling_2d_various_interpolation_methods(self):
129129
input_shape = (2, 2, 1, 3)
130130
x = np.arange(np.prod(input_shape)).reshape(input_shape)
131+
if backend.backend() == "mlx":
132+
# mlx does not support integer matmul
133+
x = x.astype("float32")
131134
for interpolation in ["nearest", "bilinear", "bicubic"]:
132135
layers.UpSampling2D(size=(1, 2), interpolation=interpolation)(x)
133136

keras/src/losses/losses.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,12 @@ def _return_labels_unconverted():
15181518
# Returns the labels unchanged if they are non-binary
15191519
return y_true
15201520

1521+
if backend.backend() == "mlx":
1522+
# ops.cond is non-compilable with mlx backend
1523+
return ops.where(
1524+
is_binary, _convert_binary_labels(), _return_labels_unconverted()
1525+
)
1526+
15211527
updated_y_true = ops.cond(
15221528
is_binary, _convert_binary_labels, _return_labels_unconverted
15231529
)

keras/src/models/model_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,12 @@ def test_functional_list_outputs_invalid_nested_list_losses(self):
779779
("float8", "float8"),
780780
)
781781
def test_quantize(self, mode):
782+
if backend.backend() == "mlx":
783+
self.skipTest(
784+
"mlx backend does not support float8."
785+
if mode == "float8"
786+
else "mlx backend does not support integer matmul"
787+
)
782788
model = _get_model()
783789
x1 = np.random.rand(2, 3)
784790
x2 = np.random.rand(2, 3)
@@ -1232,8 +1238,8 @@ def test_export_error(self):
12321238
with self.assertRaisesRegex(
12331239
NotImplementedError,
12341240
(
1235-
r"`export_saved_model` only currently supports the "
1236-
r"tensorflow, jax and torch backends."
1241+
r"`ExportArchive` is only compatible with "
1242+
r"TensorFlow, JAX and Torch backends."
12371243
),
12381244
):
12391245
model.export(temp_filepath, format="tf_saved_model")

0 commit comments

Comments
 (0)