Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions keras_hub/src/layers/modeling/reversible_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import keras
import numpy as np
import pytest
from absl.testing import parameterized
from keras import ops
from keras import random
Expand Down Expand Up @@ -95,6 +96,10 @@ def test_reverse_dtype(self):
@parameterized.named_parameters(
("tie_weights", True), ("untie_weights", False)
)
@pytest.mark.skipif(
keras.config.backend() == "mlx",
reason="quantization not yet implemented for mlx backend",
)
def test_quantize_int8(self, tie_weights):
layer_config = dict(
input_dim=100, output_dim=32, tie_weights=tie_weights
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import keras
import numpy as np
from keras import ops
from keras import random
from keras.src import backend

from keras_hub.src.layers.modeling.token_and_position_embedding import (
TokenAndPositionEmbedding,
Expand Down Expand Up @@ -34,4 +36,7 @@ def test_mask_propagation(self):
input_data = np.array([[1, 0], [1, 0]])
mask = input_data != 0
outputs = test_layer(input_data)
self.assertAllEqual(outputs._keras_mask, mask)
if keras.config.backend() == "mlx":
self.assertAllEqual(backend.get_keras_mask(outputs), mask)
else:
self.assertAllEqual(outputs._keras_mask, mask)
Comment on lines +39 to +42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve readability and reduce code duplication, you can determine the output mask based on the backend and then perform the assertion once.

Suggested change
if keras.config.backend() == "mlx":
self.assertAllEqual(backend.get_keras_mask(outputs), mask)
else:
self.assertAllEqual(outputs._keras_mask, mask)
self.assertAllEqual(
backend.get_keras_mask(outputs)
if keras.config.backend() == "mlx"
else outputs._keras_mask,
mask,
)

25 changes: 19 additions & 6 deletions keras_hub/src/layers/modeling/transformer_decoder_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras import random
from keras.src import backend

from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
from keras_hub.src.tests.test_case import TestCase
Expand Down Expand Up @@ -114,9 +116,14 @@ def test_mask_propagation(self):
decoder_sequence = random.uniform(shape=[1, 4, 6])
encoder_sequence = random.uniform(shape=[1, 4, 6])
mask = ops.array([[True, True, False, False]])
decoder_sequence._keras_mask = mask
outputs = decoder(decoder_sequence, encoder_sequence)
self.assertAllEqual(outputs._keras_mask, mask)
if keras.config.backend() == "mlx":
backend.set_keras_mask(decoder_sequence, mask)
outputs = decoder(decoder_sequence, encoder_sequence)
self.assertAllEqual(backend.get_keras_mask(outputs), mask)
else:
decoder_sequence._keras_mask = mask
outputs = decoder(decoder_sequence, encoder_sequence)
self.assertAllEqual(outputs._keras_mask, mask)
Comment on lines +119 to +126
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block contains duplicated code. You can refactor it by setting the mask first, then calling the decoder, and finally asserting the output mask. This will make the code more concise and easier to maintain.

        if keras.config.backend() == "mlx":
            backend.set_keras_mask(decoder_sequence, mask)
        else:
            decoder_sequence._keras_mask = mask
        outputs = decoder(decoder_sequence, encoder_sequence)
        output_mask = (
            backend.get_keras_mask(outputs)
            if keras.config.backend() == "mlx"
            else outputs._keras_mask
        )
        self.assertAllEqual(output_mask, mask)


def test_mask_propagation_without_cross_attention(self):
decoder = TransformerDecoder(
Expand All @@ -125,9 +132,15 @@ def test_mask_propagation_without_cross_attention(self):
)
decoder_sequence = random.uniform(shape=[1, 4, 6])
mask = ops.array([[True, True, False, False]])
decoder_sequence._keras_mask = mask
outputs = decoder(decoder_sequence)
self.assertAllEqual(outputs._keras_mask, mask)

if keras.config.backend() == "mlx":
backend.set_keras_mask(decoder_sequence, mask)
outputs = decoder(decoder_sequence)
self.assertAllEqual(backend.get_keras_mask(outputs), mask)
else:
decoder_sequence._keras_mask = mask
outputs = decoder(decoder_sequence)
self.assertAllEqual(outputs._keras_mask, mask)
Comment on lines +136 to +143
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block has duplicated code, similar to the previous test. Refactoring it will improve code quality and maintainability.

        if keras.config.backend() == "mlx":
            backend.set_keras_mask(decoder_sequence, mask)
        else:
            decoder_sequence._keras_mask = mask
        outputs = decoder(decoder_sequence)
        output_mask = (
            backend.get_keras_mask(outputs)
            if keras.config.backend() == "mlx"
            else outputs._keras_mask
        )
        self.assertAllEqual(output_mask, mask)


def test_cache_call_is_correct(self):
batch_size, seq_len, num_heads, key_dim = 2, 5, 2, 4
Expand Down
12 changes: 9 additions & 3 deletions keras_hub/src/layers/modeling/transformer_encoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from absl.testing import parameterized
from keras import ops
from keras import random
from keras.src import backend

from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
from keras_hub.src.tests.test_case import TestCase
Expand Down Expand Up @@ -92,9 +93,14 @@ def test_mask_propagation(self):
)
inputs = random.uniform(shape=[1, 4, 6])
mask = ops.array([[True, True, False, False]])
inputs._keras_mask = mask
outputs = encoder(inputs)
self.assertAllEqual(outputs._keras_mask, mask)
if keras.config.backend() == "mlx":
backend.set_keras_mask(inputs, mask)
outputs = encoder(inputs)
self.assertAllEqual(backend.get_keras_mask(outputs), mask)
else:
inputs._keras_mask = mask
outputs = encoder(inputs)
self.assertAllEqual(outputs._keras_mask, mask)
Comment on lines +96 to +103
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block contains duplicated code. You can refactor it to improve readability and maintainability by setting the mask first, then calling the encoder, and finally asserting the output mask.

Suggested change
if keras.config.backend() == "mlx":
backend.set_keras_mask(inputs, mask)
outputs = encoder(inputs)
self.assertAllEqual(backend.get_keras_mask(outputs), mask)
else:
inputs._keras_mask = mask
outputs = encoder(inputs)
self.assertAllEqual(outputs._keras_mask, mask)
if keras.config.backend() == "mlx":
backend.set_keras_mask(inputs, mask)
else:
inputs._keras_mask = mask
outputs = encoder(inputs)
output_mask = (
backend.get_keras_mask(outputs)
if keras.config.backend() == "mlx"
else outputs._keras_mask
)
self.assertAllEqual(output_mask, mask)


def test_attention_scores(self):
encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
Expand Down
12 changes: 10 additions & 2 deletions keras_hub/src/layers/preprocessing/image_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,16 @@ def test_unbatched(self):

def test_dtypes(self):
converter = ImageConverter(image_size=(4, 4), scale=1.0 / 255.0)
int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
float_image = ops.ones((10, 10, 3), dtype="float64") * 255
if keras.config.backend() == "mlx":
# mlx backend does not support int matmul
int_image = ops.ones((10, 10, 3), dtype="float16") * 255
# mlx only suports float64 on the cpu
# can force all operations onto cpu stream with float64
float_image = ops.ones((10, 10, 3), dtype="float32") * 255
else:
int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
float_image = ops.ones((10, 10, 3), dtype="float64") * 255

self.assertDTypeEqual(converter(int_image), "float32")
self.assertDTypeEqual(converter(float_image), "float32")
self.assertAllClose(converter(int_image), np.ones((4, 4, 3)))
Expand Down
6 changes: 6 additions & 0 deletions keras_hub/src/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from keras import ops
from keras import tree

# from keras.src.trainers.data_adapters import is_mlx_array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out import appears to be a leftover and should be removed to keep the code clean.

from keras_hub.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
)
from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
from keras_hub.src.tokenizers.tokenizer import Tokenizer
from keras_hub.src.utils.tensor_utils import is_float_dtype
from keras_hub.src.utils.tensor_utils import is_mlx_array


def convert_to_comparible_type(x):
Expand All @@ -34,6 +36,10 @@ def convert_to_comparible_type(x):
return x
if hasattr(x, "__array__"):
return ops.convert_to_numpy(x)
if keras.config.backend() == "mlx" and is_mlx_array(x):
# this is to handle bfloat16
# mlx arrays don't have an __array__ attribute
return ops.convert_to_numpy(x)
return x


Expand Down
18 changes: 17 additions & 1 deletion keras_hub/src/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,14 @@ def tensor_to_list(inputs):
def convert_to_ragged_batch(inputs):
"""Ensure a tf.Tensor is a ragged rank 2 tensor."""
if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
inputs = tf.convert_to_tensor(inputs)
if keras.config.backend() == "mlx":
# mlx array to tf tensor currently only supports flat arrays
array_shape = inputs.shape
inputs = inputs.flatten()
inputs = tf.convert_to_tensor(memoryview(inputs))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Using memoryview(inputs) will raise a TypeError because MLX arrays do not support the buffer protocol. You should convert the MLX array to a NumPy array before passing it to tf.convert_to_tensor.

Suggested change
inputs = tf.convert_to_tensor(memoryview(inputs))
inputs = tf.convert_to_tensor(np.array(inputs))

inputs = tf.reshape(inputs, array_shape)
else:
inputs = tf.convert_to_tensor(inputs)
unbatched = inputs.shape.rank == 1
rectangular = isinstance(inputs, tf.Tensor)
if unbatched:
Expand Down Expand Up @@ -321,6 +328,15 @@ def is_string_dtype(dtype):
return "string" in keras.backend.standardize_dtype(dtype)


def is_mlx_array(value):
if hasattr(value, "__class__"):
return (
value.__class__.__module__ == "mlx.core"
and value.__class__.__name__ == "array"
)
return False


def get_dtype_size_in_bits(dtype):
"""Get the size of a given dtype in bits."""
dtype = keras.backend.standardize_dtype(dtype)
Expand Down
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,9 @@ torchvision>=0.16.0
# Jax.
jax[cpu]

# MLX.
pybind11[global]
cmake
mlx

-r requirements-common.txt
Loading