diff --git a/keras_hub/src/layers/modeling/reversible_embedding_test.py b/keras_hub/src/layers/modeling/reversible_embedding_test.py
index 4482854449..d1e2a897a1 100644
--- a/keras_hub/src/layers/modeling/reversible_embedding_test.py
+++ b/keras_hub/src/layers/modeling/reversible_embedding_test.py
@@ -2,6 +2,7 @@
 import keras
 import numpy as np
+import pytest
 from absl.testing import parameterized
 from keras import ops
 from keras import random
 
@@ -95,6 +96,10 @@ def test_reverse_dtype(self):
     @parameterized.named_parameters(
         ("tie_weights", True), ("untie_weights", False)
     )
+    @pytest.mark.skipif(
+        keras.config.backend() == "mlx",
+        reason="quantization not yet implemented for mlx backend",
+    )
     def test_quantize_int8(self, tie_weights):
         layer_config = dict(
             input_dim=100, output_dim=32, tie_weights=tie_weights
diff --git a/keras_hub/src/layers/modeling/token_and_position_embedding_test.py b/keras_hub/src/layers/modeling/token_and_position_embedding_test.py
index f0ef202aed..542c54d588 100644
--- a/keras_hub/src/layers/modeling/token_and_position_embedding_test.py
+++ b/keras_hub/src/layers/modeling/token_and_position_embedding_test.py
@@ -1,6 +1,8 @@
+import keras
 import numpy as np
 from keras import ops
 from keras import random
+from keras.src import backend
 
 from keras_hub.src.layers.modeling.token_and_position_embedding import (
     TokenAndPositionEmbedding,
@@ -34,4 +36,7 @@ def test_mask_propagation(self):
         input_data = np.array([[1, 0], [1, 0]])
         mask = input_data != 0
         outputs = test_layer(input_data)
-        self.assertAllEqual(outputs._keras_mask, mask)
+        if keras.config.backend() == "mlx":
+            self.assertAllEqual(backend.get_keras_mask(outputs), mask)
+        else:
+            self.assertAllEqual(outputs._keras_mask, mask)
diff --git a/keras_hub/src/layers/modeling/transformer_decoder_test.py b/keras_hub/src/layers/modeling/transformer_decoder_test.py
index 7cbd32bed8..c8118ae5d6 100644
--- a/keras_hub/src/layers/modeling/transformer_decoder_test.py
+++ b/keras_hub/src/layers/modeling/transformer_decoder_test.py
@@ -1,6 +1,8 @@
+import keras
 from absl.testing import parameterized
 from keras import ops
 from keras import random
+from keras.src import backend
 
 from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
 from keras_hub.src.tests.test_case import TestCase
@@ -114,9 +116,14 @@ def test_mask_propagation(self):
         decoder_sequence = random.uniform(shape=[1, 4, 6])
         encoder_sequence = random.uniform(shape=[1, 4, 6])
         mask = ops.array([[True, True, False, False]])
-        decoder_sequence._keras_mask = mask
-        outputs = decoder(decoder_sequence, encoder_sequence)
-        self.assertAllEqual(outputs._keras_mask, mask)
+        if keras.config.backend() == "mlx":
+            backend.set_keras_mask(decoder_sequence, mask)
+            outputs = decoder(decoder_sequence, encoder_sequence)
+            self.assertAllEqual(backend.get_keras_mask(outputs), mask)
+        else:
+            decoder_sequence._keras_mask = mask
+            outputs = decoder(decoder_sequence, encoder_sequence)
+            self.assertAllEqual(outputs._keras_mask, mask)
 
     def test_mask_propagation_without_cross_attention(self):
         decoder = TransformerDecoder(
@@ -125,9 +132,15 @@ def test_mask_propagation_without_cross_attention(self):
         )
         decoder_sequence = random.uniform(shape=[1, 4, 6])
         mask = ops.array([[True, True, False, False]])
-        decoder_sequence._keras_mask = mask
-        outputs = decoder(decoder_sequence)
-        self.assertAllEqual(outputs._keras_mask, mask)
+
+        if keras.config.backend() == "mlx":
+            backend.set_keras_mask(decoder_sequence, mask)
+            outputs = decoder(decoder_sequence)
+            self.assertAllEqual(backend.get_keras_mask(outputs), mask)
+        else:
+            decoder_sequence._keras_mask = mask
+            outputs = decoder(decoder_sequence)
+            self.assertAllEqual(outputs._keras_mask, mask)
 
     def test_cache_call_is_correct(self):
         batch_size, seq_len, num_heads, key_dim = 2, 5, 2, 4
diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py
index a682af157c..324d74600a 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder_test.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py
@@ -2,6 +2,7 @@
 from absl.testing import parameterized
 from keras import ops
 from keras import random
+from keras.src import backend
 
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.tests.test_case import TestCase
@@ -92,9 +93,14 @@ def test_mask_propagation(self):
         )
         inputs = random.uniform(shape=[1, 4, 6])
         mask = ops.array([[True, True, False, False]])
-        inputs._keras_mask = mask
-        outputs = encoder(inputs)
-        self.assertAllEqual(outputs._keras_mask, mask)
+        if keras.config.backend() == "mlx":
+            backend.set_keras_mask(inputs, mask)
+            outputs = encoder(inputs)
+            self.assertAllEqual(backend.get_keras_mask(outputs), mask)
+        else:
+            inputs._keras_mask = mask
+            outputs = encoder(inputs)
+            self.assertAllEqual(outputs._keras_mask, mask)
 
     def test_attention_scores(self):
         encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
diff --git a/keras_hub/src/layers/preprocessing/image_converter_test.py b/keras_hub/src/layers/preprocessing/image_converter_test.py
index 8d47872a43..e380e298c1 100644
--- a/keras_hub/src/layers/preprocessing/image_converter_test.py
+++ b/keras_hub/src/layers/preprocessing/image_converter_test.py
@@ -52,8 +52,16 @@ def test_unbatched(self):
 
     def test_dtypes(self):
         converter = ImageConverter(image_size=(4, 4), scale=1.0 / 255.0)
-        int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
-        float_image = ops.ones((10, 10, 3), dtype="float64") * 255
+        if keras.config.backend() == "mlx":
+            # mlx backend does not support int matmul
+            int_image = ops.ones((10, 10, 3), dtype="float16") * 255
+            # mlx only supports float64 on the cpu; using float32 avoids
+            # forcing all operations onto the cpu stream
+            float_image = ops.ones((10, 10, 3), dtype="float32") * 255
+        else:
+            int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
+            float_image = ops.ones((10, 10, 3), dtype="float64") * 255
+
         self.assertDTypeEqual(converter(int_image), "float32")
         self.assertDTypeEqual(converter(float_image), "float32")
         self.assertAllClose(converter(int_image), np.ones((4, 4, 3)))
diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py
index 43eb8050c3..ae95f364fe 100644
--- a/keras_hub/src/tests/test_case.py
+++ b/keras_hub/src/tests/test_case.py
@@ -10,12 +10,13 @@
 from keras import ops
 from keras import tree
 
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )
 from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
 from keras_hub.src.tokenizers.tokenizer import Tokenizer
 from keras_hub.src.utils.tensor_utils import is_float_dtype
+from keras_hub.src.utils.tensor_utils import is_mlx_array
 
 
 def convert_to_comparible_type(x):
@@ -34,6 +35,10 @@ def convert_to_comparible_type(x):
         return x
     if hasattr(x, "__array__"):
         return ops.convert_to_numpy(x)
+    if keras.config.backend() == "mlx" and is_mlx_array(x):
+        # mlx arrays don't have an __array__ attribute, so convert
+        # them explicitly (this also handles bfloat16)
+        return ops.convert_to_numpy(x)
     return x
 
 
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index 47305a3f01..d1dd9c36d5 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -245,7 +245,14 @@ def tensor_to_list(inputs):
 def convert_to_ragged_batch(inputs):
     """Ensure a tf.Tensor is a ragged rank 2 tensor."""
     if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
-        inputs = tf.convert_to_tensor(inputs)
+        if is_mlx_array(inputs):
+            # mlx array to tf tensor currently only supports flat arrays
+            array_shape = inputs.shape
+            inputs = inputs.flatten()
+            inputs = tf.convert_to_tensor(memoryview(inputs))
+            inputs = tf.reshape(inputs, array_shape)
+        else:
+            inputs = tf.convert_to_tensor(inputs)
     unbatched = inputs.shape.rank == 1
     rectangular = isinstance(inputs, tf.Tensor)
     if unbatched:
@@ -321,6 +328,16 @@ def is_string_dtype(dtype):
     return "string" in keras.backend.standardize_dtype(dtype)
 
 
+def is_mlx_array(value):
+    """Check whether `value` is an mlx array without importing mlx."""
+    if hasattr(value, "__class__"):
+        return (
+            value.__class__.__module__ == "mlx.core"
+            and value.__class__.__name__ == "array"
+        )
+    return False
+
+
 def get_dtype_size_in_bits(dtype):
     """Get the size of a given dtype in bits."""
     dtype = keras.backend.standardize_dtype(dtype)
diff --git a/requirements.txt b/requirements.txt
index e499522558..ee2077cf41 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,9 @@ torchvision>=0.16.0
 # Jax.
 jax[cpu]
 
+# MLX.
+pybind11[global]
+cmake
+mlx
+
 -r requirements-common.txt