-
Notifications
You must be signed in to change notification settings - Fork 301
[MLX backend] Support for MLX backend across layers
tests
#2351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
import keras | ||
from absl.testing import parameterized | ||
from keras import ops | ||
from keras import random | ||
from keras.src import backend | ||
|
||
from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
@@ -114,9 +116,14 @@ def test_mask_propagation(self): | |
decoder_sequence = random.uniform(shape=[1, 4, 6]) | ||
encoder_sequence = random.uniform(shape=[1, 4, 6]) | ||
mask = ops.array([[True, True, False, False]]) | ||
decoder_sequence._keras_mask = mask | ||
outputs = decoder(decoder_sequence, encoder_sequence) | ||
self.assertAllEqual(outputs._keras_mask, mask) | ||
if keras.config.backend() == "mlx": | ||
backend.set_keras_mask(decoder_sequence, mask) | ||
outputs = decoder(decoder_sequence, encoder_sequence) | ||
self.assertAllEqual(backend.get_keras_mask(outputs), mask) | ||
else: | ||
decoder_sequence._keras_mask = mask | ||
outputs = decoder(decoder_sequence, encoder_sequence) | ||
self.assertAllEqual(outputs._keras_mask, mask) | ||
Comment on lines
+119
to
+126
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block contains duplicated code. You can refactor it by setting the mask first, then calling the decoder, and finally asserting the output mask. This will make the code more concise and easier to maintain. if keras.config.backend() == "mlx":
backend.set_keras_mask(decoder_sequence, mask)
else:
decoder_sequence._keras_mask = mask
outputs = decoder(decoder_sequence, encoder_sequence)
output_mask = (
backend.get_keras_mask(outputs)
if keras.config.backend() == "mlx"
else outputs._keras_mask
)
self.assertAllEqual(output_mask, mask) |
||
|
||
def test_mask_propagation_without_cross_attention(self): | ||
decoder = TransformerDecoder( | ||
|
@@ -125,9 +132,15 @@ def test_mask_propagation_without_cross_attention(self): | |
) | ||
decoder_sequence = random.uniform(shape=[1, 4, 6]) | ||
mask = ops.array([[True, True, False, False]]) | ||
decoder_sequence._keras_mask = mask | ||
outputs = decoder(decoder_sequence) | ||
self.assertAllEqual(outputs._keras_mask, mask) | ||
|
||
if keras.config.backend() == "mlx": | ||
backend.set_keras_mask(decoder_sequence, mask) | ||
outputs = decoder(decoder_sequence) | ||
self.assertAllEqual(backend.get_keras_mask(outputs), mask) | ||
else: | ||
decoder_sequence._keras_mask = mask | ||
outputs = decoder(decoder_sequence) | ||
self.assertAllEqual(outputs._keras_mask, mask) | ||
Comment on lines
+136
to
+143
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has duplicated code, similar to the previous test. Refactoring it will improve code quality and maintainability. if keras.config.backend() == "mlx":
backend.set_keras_mask(decoder_sequence, mask)
else:
decoder_sequence._keras_mask = mask
outputs = decoder(decoder_sequence)
output_mask = (
backend.get_keras_mask(outputs)
if keras.config.backend() == "mlx"
else outputs._keras_mask
)
self.assertAllEqual(output_mask, mask) |
||
|
||
def test_cache_call_is_correct(self): | ||
batch_size, seq_len, num_heads, key_dim = 2, 5, 2, 4 | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,6 +2,7 @@ | |||||||||||||||||||||||||||||||||||||||
from absl.testing import parameterized | ||||||||||||||||||||||||||||||||||||||||
from keras import ops | ||||||||||||||||||||||||||||||||||||||||
from keras import random | ||||||||||||||||||||||||||||||||||||||||
from keras.src import backend | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder | ||||||||||||||||||||||||||||||||||||||||
from keras_hub.src.tests.test_case import TestCase | ||||||||||||||||||||||||||||||||||||||||
|
@@ -92,9 +93,14 @@ def test_mask_propagation(self): | |||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
inputs = random.uniform(shape=[1, 4, 6]) | ||||||||||||||||||||||||||||||||||||||||
mask = ops.array([[True, True, False, False]]) | ||||||||||||||||||||||||||||||||||||||||
inputs._keras_mask = mask | ||||||||||||||||||||||||||||||||||||||||
outputs = encoder(inputs) | ||||||||||||||||||||||||||||||||||||||||
self.assertAllEqual(outputs._keras_mask, mask) | ||||||||||||||||||||||||||||||||||||||||
if keras.config.backend() == "mlx": | ||||||||||||||||||||||||||||||||||||||||
backend.set_keras_mask(inputs, mask) | ||||||||||||||||||||||||||||||||||||||||
outputs = encoder(inputs) | ||||||||||||||||||||||||||||||||||||||||
self.assertAllEqual(backend.get_keras_mask(outputs), mask) | ||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||
inputs._keras_mask = mask | ||||||||||||||||||||||||||||||||||||||||
outputs = encoder(inputs) | ||||||||||||||||||||||||||||||||||||||||
self.assertAllEqual(outputs._keras_mask, mask) | ||||||||||||||||||||||||||||||||||||||||
Comment on lines
+96
to
+103
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block contains duplicated code. You can refactor it to improve readability and maintainability by setting the mask first, then calling the encoder, and finally asserting the output mask.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def test_attention_scores(self): | ||||||||||||||||||||||||||||||||||||||||
encoder = TransformerEncoder(intermediate_dim=4, num_heads=2) | ||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,12 +10,14 @@ | |
from keras import ops | ||
from keras import tree | ||
|
||
# from keras.src.trainers.data_adapters import is_mlx_array | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
from keras_hub.src.layers.modeling.reversible_embedding import ( | ||
ReversibleEmbedding, | ||
) | ||
from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid | ||
from keras_hub.src.tokenizers.tokenizer import Tokenizer | ||
from keras_hub.src.utils.tensor_utils import is_float_dtype | ||
from keras_hub.src.utils.tensor_utils import is_mlx_array | ||
|
||
|
||
def convert_to_comparible_type(x): | ||
|
@@ -34,6 +36,10 @@ def convert_to_comparible_type(x): | |
return x | ||
if hasattr(x, "__array__"): | ||
return ops.convert_to_numpy(x) | ||
if keras.config.backend() == "mlx" and is_mlx_array(x): | ||
# this is to handle bfloat16 | ||
# mlx arrays don't have an __array__ attribute | ||
return ops.convert_to_numpy(x) | ||
return x | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -245,7 +245,14 @@ def tensor_to_list(inputs): | |||||
def convert_to_ragged_batch(inputs): | ||||||
"""Ensure a tf.Tensor is a ragged rank 2 tensor.""" | ||||||
if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)): | ||||||
inputs = tf.convert_to_tensor(inputs) | ||||||
if keras.config.backend() == "mlx": | ||||||
# mlx array to tf tensor currently only supports flat arrays | ||||||
array_shape = inputs.shape | ||||||
inputs = inputs.flatten() | ||||||
inputs = tf.convert_to_tensor(memoryview(inputs)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||
inputs = tf.reshape(inputs, array_shape) | ||||||
else: | ||||||
inputs = tf.convert_to_tensor(inputs) | ||||||
unbatched = inputs.shape.rank == 1 | ||||||
rectangular = isinstance(inputs, tf.Tensor) | ||||||
if unbatched: | ||||||
|
@@ -321,6 +328,15 @@ def is_string_dtype(dtype): | |||||
return "string" in keras.backend.standardize_dtype(dtype) | ||||||
|
||||||
|
||||||
def is_mlx_array(value): | ||||||
if hasattr(value, "__class__"): | ||||||
return ( | ||||||
value.__class__.__module__ == "mlx.core" | ||||||
and value.__class__.__name__ == "array" | ||||||
) | ||||||
return False | ||||||
|
||||||
|
||||||
def get_dtype_size_in_bits(dtype): | ||||||
"""Get the size of a given dtype in bits.""" | ||||||
dtype = keras.backend.standardize_dtype(dtype) | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,4 +11,9 @@ torchvision>=0.16.0 | |
# Jax. | ||
jax[cpu] | ||
|
||
# MLX. | ||
pybind11[global] | ||
cmake | ||
mlx | ||
|
||
-r requirements-common.txt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve readability and reduce code duplication, you can determine the output mask based on the backend and then perform the assertion once.