diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index fe220e2d43..3fdc420c18 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -285,6 +285,9 @@ from keras_hub.src.models.gemma3.gemma3_backbone import ( Gemma3Backbone as Gemma3Backbone, ) +from keras_hub.src.models.gemma3.gemma3_backbone import ( + Gemma3EmbeddingModel as Gemma3Embedding, +) from keras_hub.src.models.gemma3.gemma3_causal_lm import ( Gemma3CausalLM as Gemma3CausalLM, ) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone.py b/keras_hub/src/models/gemma3/gemma3_backbone.py index a65dbd726b..c882901651 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone.py @@ -1,4 +1,5 @@ import keras +from keras import layers from keras import ops from keras_hub.src.api_export import keras_hub_export @@ -424,3 +425,129 @@ def from_config(cls, config): ) return super().from_config(config) + + +class MeanPooling(layers.Layer): + """ + Mean pooling layer that computes the average of token embeddings, + respecting a padding mask. + + This layer correctly handles variable-length sequences by ignoring + padded tokens in the mean calculation. + + Call arguments: + inputs: A tuple of `(sequence_output, padding_mask)`. + `sequence_output` is a tensor of shape `(batch_size, seq_len, + hidden_dim)`. `padding_mask` is a tensor of shape `(batch_size, + seq_len)` with `1` for valid tokens and `0` for padded tokens. + + Returns: + A tensor of shape `(batch_size, hidden_dim)`. + + Example: + ```python + sequence_output = np.random.rand(2, 4, 8).astype("float32") + padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]]) + mean_pool_layer = MeanPooling() + pooled = mean_pool_layer((sequence_output, padding_mask)) + # pooled.shape -> (2, 8) + ``` + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def call(self, inputs): + sequence_output, padding_mask = inputs + + mask = ops.expand_dims( + ops.cast(padding_mask, sequence_output.dtype), axis=-1 + ) + masked_output = sequence_output * mask + sum_embeddings = ops.sum(masked_output, axis=1) + num_tokens = ops.sum( + ops.cast(padding_mask, sequence_output.dtype), axis=1 + ) + num_tokens = ops.expand_dims(num_tokens, axis=1) + + num_tokens = ops.maximum(num_tokens, 1e-9) + + mean_embeddings = sum_embeddings / num_tokens + return mean_embeddings + + def compute_output_shape(self, input_shape): + sequence_output_shape, padding_mask_shape = input_shape + return (sequence_output_shape[0], sequence_output_shape[2]) + + def get_config(self): + return super().get_config() + + +@keras_hub_export("keras_hub.models.Gemma3Embedding") +class Gemma3EmbeddingModel(keras.Model): + """An end-to-end Gemma3 model for embedding tasks. + + This model takes token ids as input and returns a fixed-size embedding + vector for the input sequence. It uses a `Gemma3Backbone` to generate + contextualized token embeddings, a `MeanPooling` layer to pool them into a + single vector, and a final `Dense` layer to project to the desired + embedding dimension. + + This model can be loaded with a pre-trained `Gemma3Backbone` and used for + tasks like semantic similarity, retrieval, or as a feature extractor. + + Args: + backbone: A `keras_hub.models.Gemma3Backbone` instance. + embedding_dim (int): The dimension of the output embedding. 
+ + Example: + ```python + # backbone = keras_hub.models.Gemma3Backbone.from_preset( + # "gemma3_instruct_1b" + # ) + # embedding_model = keras_hub.models.Gemma3EmbeddingModel( + # backbone=backbone, + # embedding_dim=768, + # ) + # input_data = { + # "token_ids": np.array([[651, 4320, 8426, 25341, 235265]]), + # "padding_mask": np.ones((1, 5), dtype="int32"), + # } + # embeddings = embedding_model.predict(input_data) + ``` + """ + + def __init__(self, backbone, embedding_dim, **kwargs): + super().__init__(**kwargs) + self.backbone = backbone + self.pooling_layer = MeanPooling( + dtype=backbone.dtype, name="mean_pooling" + ) + self.projection_layer = layers.Dense( + embedding_dim, dtype=backbone.dtype, name="embedding_projection" + ) + self.embedding_dim = embedding_dim + + def call(self, inputs): + sequence_output = self.backbone(inputs) + padding_mask = inputs["padding_mask"] + + pooled_output = self.pooling_layer((sequence_output, padding_mask)) + embedding = self.projection_layer(pooled_output) + return embedding + + def get_config(self): + config = super().get_config() + config.update( + { + "backbone": layers.serialize(self.backbone), + "embedding_dim": self.embedding_dim, + } + ) + return config + + @classmethod + def from_config(cls, config): + config["backbone"] = layers.deserialize(config["backbone"]) + return cls(**config) diff --git a/keras_hub/src/models/gemma3/gemma3_backbone_test.py b/keras_hub/src/models/gemma3/gemma3_backbone_test.py index 7eb31f9ff6..8e6f3f31ae 100644 --- a/keras_hub/src/models/gemma3/gemma3_backbone_test.py +++ b/keras_hub/src/models/gemma3/gemma3_backbone_test.py @@ -6,6 +6,7 @@ from keras import ops from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3EmbeddingModel from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( Gemma3VisionEncoder, ) @@ -193,3 +194,96 @@ def test_all_presets(self): if "_text" in preset or "1b" in preset else self.input_data, ) + + +class Gemma3EmbeddingModelTest(TestCase, parameterized.TestCase): + def setUp(self): + self.batch_size = 2 + self.vocabulary_size = 256 + self.text_sequence_length = 64 + self.hidden_dim = 8 + self.embedding_dim = 16 + + self.backbone = Gemma3Backbone( + vocabulary_size=self.vocabulary_size, + image_size=16, + num_layers=2, + num_query_heads=2, + num_key_value_heads=1, + hidden_dim=self.hidden_dim, + intermediate_dim=32, + head_dim=4, + vision_encoder=None, + ) + + self.init_kwargs = { + "backbone": self.backbone, + "embedding_dim": self.embedding_dim, + } + + dummy_text_token_ids = np.random.randint( + 0, + self.vocabulary_size, + (self.batch_size, self.text_sequence_length), + ) + padding_mask = np.ones( + (self.batch_size, self.text_sequence_length), dtype="int32" + ) + padding_mask[0, -10:] = 0 + padding_mask[1, -5:] = 0 + + self.input_data = { + "token_ids": dummy_text_token_ids, + "padding_mask": padding_mask, + } + + def test_model_basics(self): + """Test the model's forward pass and output shape.""" + model = Gemma3EmbeddingModel(**self.init_kwargs) + output = model(self.input_data) + expected_output_shape = (self.batch_size, self.embedding_dim) + self.assertEqual(output.shape, expected_output_shape) + + def test_architecture_characteristics(self): + """Test parameter and layer counts.""" + model = Gemma3EmbeddingModel(**self.init_kwargs) + + model(self.input_data) + + backbone_params = self.backbone.count_params() + projection_params = ( + self.hidden_dim * self.embedding_dim + ) + self.embedding_dim + 
+        expected_params = backbone_params + projection_params
+
+        expected_layers = 3
+
+        self.assertEqual(model.count_params(), expected_params)
+        self.assertEqual(len(model.layers), expected_layers)
+
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=Gemma3EmbeddingModel,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
+
+    @pytest.mark.kaggle_key_required
+    @pytest.mark.extra_large
+    def test_build_from_preset_backbone(self):
+        backbone = Gemma3Backbone.from_preset("gemma3_instruct_1b_text")
+        model = Gemma3EmbeddingModel(
+            backbone=backbone,
+            embedding_dim=768,
+        )
+
+        input_data = {
+            "token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]),
+            "padding_mask": ops.ones((1, 5), dtype="int32"),
+        }
+
+        outputs = model(input_data)
+
+        self.assertEqual(outputs.shape, (1, 768))
+        norm = ops.norm(outputs, axis=1)
+        self.assertGreater(norm[0], 0)
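
A minimal standalone sketch (not part of the diff) of what the new `MeanPooling` layer computes: the masked mean should match a NumPy reference that sums only unmasked token embeddings and divides by the per-sequence token count. It assumes the layer is importable from the module patched above.

```python
import numpy as np
from keras import ops

from keras_hub.src.models.gemma3.gemma3_backbone import MeanPooling

# Two sequences of length 4 with hidden_dim 8; trailing positions are padding.
sequence_output = np.random.rand(2, 4, 8).astype("float32")
padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype="int32")

pooled = MeanPooling()((sequence_output, padding_mask))  # shape (2, 8)

# NumPy reference: sum the unmasked token embeddings and divide by the
# number of unmasked tokens in each sequence.
mask = padding_mask[..., None].astype("float32")
expected = (sequence_output * mask).sum(axis=1) / mask.sum(axis=1)
np.testing.assert_allclose(ops.convert_to_numpy(pooled), expected, rtol=1e-5)
```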
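
And a hedged end-to-end sketch of the intended workflow, mirroring the class docstring and the preset test: embed two tokenized sentences and compare them with cosine similarity. The preset name is taken from the test above; the token ids for the second sentence are made up for illustration and would normally come from the Gemma3 preprocessor.

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_1b_text")
embedding_model = keras_hub.models.Gemma3Embedding(
    backbone=backbone,
    embedding_dim=768,
)

inputs = {
    # First row is the docstring example; second row is hypothetical.
    "token_ids": np.array(
        [
            [651, 4320, 8426, 25341, 235265],
            [651, 2871, 603, 3413, 235265],
        ]
    ),
    "padding_mask": np.ones((2, 5), dtype="int32"),
}
embeddings = embedding_model.predict(inputs)  # shape (2, 768)

# Cosine similarity between the two sentence embeddings.
unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
similarity = float(unit[0] @ unit[1])
```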