From 0b647e47a28952219b1931ff43d4e2ff2cee341f Mon Sep 17 00:00:00 2001
From: Andrew Sweet
Date: Sun, 13 Jul 2025 23:51:46 -0700
Subject: [PATCH 1/2] updates for tests with the mlx backend in src/layers

---
 .../modeling/reversible_embedding_test.py     |  2 ++
 .../token_and_position_embedding_test.py      |  7 +++++-
 .../modeling/transformer_decoder_test.py      | 25 ++++++++++++++-----
 .../modeling/transformer_encoder_test.py      | 12 ++++++---
 .../preprocessing/image_converter_test.py     | 12 +++++++--
 keras_hub/src/tests/test_case.py              |  6 +++++
 keras_hub/src/utils/tensor_utils.py           | 18 ++++++++++++-
 requirements.txt                              |  5 ++++
 8 files changed, 74 insertions(+), 13 deletions(-)

diff --git a/keras_hub/src/layers/modeling/reversible_embedding_test.py b/keras_hub/src/layers/modeling/reversible_embedding_test.py
index 4482854449..c0ac01a1bd 100644
--- a/keras_hub/src/layers/modeling/reversible_embedding_test.py
+++ b/keras_hub/src/layers/modeling/reversible_embedding_test.py
@@ -2,6 +2,7 @@
 
 import keras
 import numpy as np
+import pytest
 from absl.testing import parameterized
 from keras import ops
 from keras import random
@@ -95,6 +96,7 @@ def test_reverse_dtype(self):
     @parameterized.named_parameters(
         ("tie_weights", True), ("untie_weights", False)
     )
+    @pytest.mark.skipif(keras.config.backend() == "mlx", reason="quantization not yet implemented for mlx backend")
     def test_quantize_int8(self, tie_weights):
         layer_config = dict(
             input_dim=100, output_dim=32, tie_weights=tie_weights
diff --git a/keras_hub/src/layers/modeling/token_and_position_embedding_test.py b/keras_hub/src/layers/modeling/token_and_position_embedding_test.py
index f0ef202aed..542c54d588 100644
--- a/keras_hub/src/layers/modeling/token_and_position_embedding_test.py
+++ b/keras_hub/src/layers/modeling/token_and_position_embedding_test.py
@@ -1,6 +1,8 @@
+import keras
 import numpy as np
 from keras import ops
 from keras import random
+from keras.src import backend
 
 from keras_hub.src.layers.modeling.token_and_position_embedding import (
     TokenAndPositionEmbedding,
@@ -34,4 +36,7 @@ def test_mask_propagation(self):
         input_data = np.array([[1, 0], [1, 0]])
         mask = input_data != 0
         outputs = test_layer(input_data)
-        self.assertAllEqual(outputs._keras_mask, mask)
+        if keras.config.backend() == "mlx":
+            self.assertAllEqual(backend.get_keras_mask(outputs), mask)
+        else:
+            self.assertAllEqual(outputs._keras_mask, mask)
diff --git a/keras_hub/src/layers/modeling/transformer_decoder_test.py b/keras_hub/src/layers/modeling/transformer_decoder_test.py
index 7cbd32bed8..c8118ae5d6 100644
--- a/keras_hub/src/layers/modeling/transformer_decoder_test.py
+++ b/keras_hub/src/layers/modeling/transformer_decoder_test.py
@@ -1,6 +1,8 @@
+import keras
 from absl.testing import parameterized
 from keras import ops
 from keras import random
+from keras.src import backend
 
 from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
 from keras_hub.src.tests.test_case import TestCase
@@ -114,9 +116,14 @@ def test_mask_propagation(self):
         decoder_sequence = random.uniform(shape=[1, 4, 6])
         encoder_sequence = random.uniform(shape=[1, 4, 6])
         mask = ops.array([[True, True, False, False]])
-        decoder_sequence._keras_mask = mask
-        outputs = decoder(decoder_sequence, encoder_sequence)
-        self.assertAllEqual(outputs._keras_mask, mask)
+        if keras.config.backend() == "mlx":
+            backend.set_keras_mask(decoder_sequence, mask)
+            outputs = decoder(decoder_sequence, encoder_sequence)
+            self.assertAllEqual(backend.get_keras_mask(outputs), mask)
+        else:
+            decoder_sequence._keras_mask = mask
+            outputs = decoder(decoder_sequence, encoder_sequence)
+            self.assertAllEqual(outputs._keras_mask, mask)
 
     def test_mask_propagation_without_cross_attention(self):
         decoder = TransformerDecoder(
@@ -125,9 +132,15 @@ def test_mask_propagation_without_cross_attention(self):
         )
         decoder_sequence = random.uniform(shape=[1, 4, 6])
         mask = ops.array([[True, True, False, False]])
-        decoder_sequence._keras_mask = mask
-        outputs = decoder(decoder_sequence)
-        self.assertAllEqual(outputs._keras_mask, mask)
+
+        if keras.config.backend() == "mlx":
+            backend.set_keras_mask(decoder_sequence, mask)
+            outputs = decoder(decoder_sequence)
+            self.assertAllEqual(backend.get_keras_mask(outputs), mask)
+        else:
+            decoder_sequence._keras_mask = mask
+            outputs = decoder(decoder_sequence)
+            self.assertAllEqual(outputs._keras_mask, mask)
 
     def test_cache_call_is_correct(self):
         batch_size, seq_len, num_heads, key_dim = 2, 5, 2, 4
diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py
index a682af157c..324d74600a 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder_test.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py
@@ -2,6 +2,7 @@
 from absl.testing import parameterized
 from keras import ops
 from keras import random
+from keras.src import backend
 
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.tests.test_case import TestCase
@@ -92,9 +93,14 @@ def test_mask_propagation(self):
         )
         inputs = random.uniform(shape=[1, 4, 6])
         mask = ops.array([[True, True, False, False]])
-        inputs._keras_mask = mask
-        outputs = encoder(inputs)
-        self.assertAllEqual(outputs._keras_mask, mask)
+        if keras.config.backend() == "mlx":
+            backend.set_keras_mask(inputs, mask)
+            outputs = encoder(inputs)
+            self.assertAllEqual(backend.get_keras_mask(outputs), mask)
+        else:
+            inputs._keras_mask = mask
+            outputs = encoder(inputs)
+            self.assertAllEqual(outputs._keras_mask, mask)
 
     def test_attention_scores(self):
         encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
diff --git a/keras_hub/src/layers/preprocessing/image_converter_test.py b/keras_hub/src/layers/preprocessing/image_converter_test.py
index 8d47872a43..b8cb371cba 100644
--- a/keras_hub/src/layers/preprocessing/image_converter_test.py
+++ b/keras_hub/src/layers/preprocessing/image_converter_test.py
@@ -52,8 +52,16 @@ def test_unbatched(self):
 
     def test_dtypes(self):
         converter = ImageConverter(image_size=(4, 4), scale=1.0 / 255.0)
-        int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
-        float_image = ops.ones((10, 10, 3), dtype="float64") * 255
+        if keras.config.backend() == "mlx":
+            # mlx backend does not support int matmul
+            int_image = ops.ones((10, 10, 3), dtype="float16") * 255
+            # mlx only supports float64 on the cpu
+            # can force all operations onto cpu stream with float64
+            float_image = ops.ones((10, 10, 3), dtype="float32") * 255
+        else:
+            int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
+            float_image = ops.ones((10, 10, 3), dtype="float64") * 255
+
         self.assertDTypeEqual(converter(int_image), "float32")
         self.assertDTypeEqual(converter(float_image), "float32")
         self.assertAllClose(converter(int_image), np.ones((4, 4, 3)))
diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py
index 43eb8050c3..c41f67f1b8 100644
--- a/keras_hub/src/tests/test_case.py
+++ b/keras_hub/src/tests/test_case.py
@@ -9,6 +9,7 @@
 from absl.testing import parameterized
 from keras import ops
 from keras import tree
+# from keras.src.trainers.data_adapters import is_mlx_array
 
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
@@ -16,6 +17,7 @@
 from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
 from keras_hub.src.tokenizers.tokenizer import Tokenizer
 from keras_hub.src.utils.tensor_utils import is_float_dtype
+from keras_hub.src.utils.tensor_utils import is_mlx_array
 
 
 def convert_to_comparible_type(x):
@@ -34,6 +36,10 @@ def convert_to_comparible_type(x):
         return x
     if hasattr(x, "__array__"):
         return ops.convert_to_numpy(x)
+    if keras.config.backend() == "mlx" and is_mlx_array(x):
+        # this is to handle bfloat16
+        # mlx arrays don't have an __array__ attribute
+        return ops.convert_to_numpy(x)
     return x
 
 
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index 47305a3f01..d1dd9c36d5 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -245,7 +245,14 @@ def tensor_to_list(inputs):
 def convert_to_ragged_batch(inputs):
     """Ensure a tf.Tensor is a ragged rank 2 tensor."""
     if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
-        inputs = tf.convert_to_tensor(inputs)
+        if keras.config.backend() == "mlx":
+            # mlx array to tf tensor currently only supports flat arrays
+            array_shape = inputs.shape
+            inputs = inputs.flatten()
+            inputs = tf.convert_to_tensor(memoryview(inputs))
+            inputs = tf.reshape(inputs, array_shape)
+        else:
+            inputs = tf.convert_to_tensor(inputs)
     unbatched = inputs.shape.rank == 1
     rectangular = isinstance(inputs, tf.Tensor)
     if unbatched:
@@ -321,6 +328,15 @@ def is_string_dtype(dtype):
     return "string" in keras.backend.standardize_dtype(dtype)
 
 
+def is_mlx_array(value):
+    if hasattr(value, "__class__"):
+        return (
+            value.__class__.__module__ == "mlx.core"
+            and value.__class__.__name__ == "array"
+        )
+    return False
+
+
 def get_dtype_size_in_bits(dtype):
     """Get the size of a given dtype in bits."""
     dtype = keras.backend.standardize_dtype(dtype)
diff --git a/requirements.txt b/requirements.txt
index e499522558..ee2077cf41 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,9 @@ torchvision>=0.16.0
 # Jax.
 jax[cpu]
 
+# MLX.
+pybind11[global]
+cmake
+mlx
+
 -r requirements-common.txt

From 84075060b3079f9b5daafc6ebc392e345ac59197 Mon Sep 17 00:00:00 2001
From: Andrew Sweet
Date: Sat, 19 Jul 2025 00:51:43 -0700
Subject: [PATCH 2/2] formatting

---
 keras_hub/src/layers/modeling/reversible_embedding_test.py | 5 ++++-
 keras_hub/src/layers/preprocessing/image_converter_test.py | 2 +-
 keras_hub/src/tests/test_case.py                           | 2 +-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/keras_hub/src/layers/modeling/reversible_embedding_test.py b/keras_hub/src/layers/modeling/reversible_embedding_test.py
index c0ac01a1bd..d1e2a897a1 100644
--- a/keras_hub/src/layers/modeling/reversible_embedding_test.py
+++ b/keras_hub/src/layers/modeling/reversible_embedding_test.py
@@ -96,7 +96,10 @@ def test_reverse_dtype(self):
     @parameterized.named_parameters(
         ("tie_weights", True), ("untie_weights", False)
     )
-    @pytest.mark.skipif(keras.config.backend() == "mlx", reason="quantization not yet implemented for mlx backend")
+    @pytest.mark.skipif(
+        keras.config.backend() == "mlx",
+        reason="quantization not yet implemented for mlx backend",
+    )
     def test_quantize_int8(self, tie_weights):
         layer_config = dict(
             input_dim=100, output_dim=32, tie_weights=tie_weights
diff --git a/keras_hub/src/layers/preprocessing/image_converter_test.py b/keras_hub/src/layers/preprocessing/image_converter_test.py
index b8cb371cba..e380e298c1 100644
--- a/keras_hub/src/layers/preprocessing/image_converter_test.py
+++ b/keras_hub/src/layers/preprocessing/image_converter_test.py
@@ -58,7 +58,7 @@ def test_dtypes(self):
             # mlx only supports float64 on the cpu
             # can force all operations onto cpu stream with float64
             float_image = ops.ones((10, 10, 3), dtype="float32") * 255
-        else: 
+        else:
             int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
             float_image = ops.ones((10, 10, 3), dtype="float64") * 255
 
diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py
index c41f67f1b8..ae95f364fe 100644
--- a/keras_hub/src/tests/test_case.py
+++ b/keras_hub/src/tests/test_case.py
@@ -9,8 +9,8 @@
 from absl.testing import parameterized
 from keras import ops
 from keras import tree
-# from keras.src.trainers.data_adapters import is_mlx_array
 
+# from keras.src.trainers.data_adapters import is_mlx_array
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )