3 changes: 3 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -285,6 +285,9 @@
from keras_hub.src.models.gemma3.gemma3_backbone import (
Gemma3Backbone as Gemma3Backbone,
)
from keras_hub.src.models.gemma3.gemma3_backbone import (
Gemma3EmbeddingModel as Gemma3Embedding,
)
from keras_hub.src.models.gemma3.gemma3_causal_lm import (
Gemma3CausalLM as Gemma3CausalLM,
)
127 changes: 127 additions & 0 deletions keras_hub/src/models/gemma3/gemma3_backbone.py
@@ -1,4 +1,5 @@
import keras
from keras import layers
from keras import ops

from keras_hub.src.api_export import keras_hub_export
@@ -424,3 +425,129 @@ def from_config(cls, config):
)

return super().from_config(config)


class MeanPooling(layers.Layer):
Collaborator: Let's move this to gemma3_mean_pooling.py?

"""
Mean pooling layer that computes the average of token embeddings,
respecting a padding mask.

This layer correctly handles variable-length sequences by ignoring
padded tokens in the mean calculation.

Call arguments:
inputs: A tuple of `(sequence_output, padding_mask)`.
`sequence_output` is a tensor of shape `(batch_size, seq_len,
hidden_dim)`. `padding_mask` is a tensor of shape `(batch_size,
seq_len)` with `1` for valid tokens and `0` for padded tokens.

Returns:
A tensor of shape `(batch_size, hidden_dim)`.

Example:
```python
sequence_output = np.random.rand(2, 4, 8).astype("float32")
padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
mean_pool_layer = MeanPooling()
pooled = mean_pool_layer((sequence_output, padding_mask))
# pooled.shape -> (2, 8)
```
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True

def call(self, inputs):
Collaborator: Can we just pass these arguments to call(...) separately instead of as a tuple?
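A minimal sketch of that suggestion, keeping the same masking logic but taking the two tensors as separate call arguments; the signature and names below are illustrative, not the merged implementation:

```python
# Hypothetical alternative: separate call arguments instead of a tuple.
def call(self, sequence_output, padding_mask):
    # Broadcast the mask over the hidden dimension.
    mask = ops.expand_dims(
        ops.cast(padding_mask, sequence_output.dtype), axis=-1
    )
    # Sum only the valid (unmasked) token embeddings.
    sum_embeddings = ops.sum(sequence_output * mask, axis=1)
    # Count valid tokens per sequence, guarding against division by zero.
    num_tokens = ops.maximum(ops.sum(mask, axis=1), 1e-9)
    return sum_embeddings / num_tokens
```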

sequence_output, padding_mask = inputs

mask = ops.expand_dims(
ops.cast(padding_mask, sequence_output.dtype), axis=-1
)
masked_output = sequence_output * mask
sum_embeddings = ops.sum(masked_output, axis=1)
num_tokens = ops.sum(
ops.cast(padding_mask, sequence_output.dtype), axis=1
)
num_tokens = ops.expand_dims(num_tokens, axis=1)

num_tokens = ops.maximum(num_tokens, 1e-9)

mean_embeddings = sum_embeddings / num_tokens
return mean_embeddings

def compute_output_shape(self, input_shape):
sequence_output_shape, padding_mask_shape = input_shape
return (sequence_output_shape[0], sequence_output_shape[2])

def get_config(self):
return super().get_config()


@keras_hub_export("keras_hub.models.Gemma3Embedding")
class Gemma3EmbeddingModel(keras.Model):
Collaborator: I think we should just stick the layers on top in Gemma3Backbone instead of creating a separate class for this.

Can probably add an argument named is_embedding_model, maybe. If is_embedding_model is True, add the dense layers + mean pooling layer, and return a dictionary containing the hidden states and the pooled output.

What do you think?
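A rough sketch of that alternative, assuming a hypothetical is_embedding_model flag and embedding_dim argument on Gemma3Backbone; the wiring and names below are placeholders, not the final design:

```python
# Hypothetical: inside Gemma3Backbone.__init__, after sequence_output is built.
if is_embedding_model:
    # Pool the hidden states, respecting the padding mask input.
    pooled = MeanPooling(dtype=dtype, name="mean_pooling")(
        (sequence_output, padding_mask_input)
    )
    # Project to the embedding dimension.
    pooled = layers.Dense(
        embedding_dim, dtype=dtype, name="embedding_projection"
    )(pooled)
    # Return both hidden states and the pooled embedding.
    outputs = {"sequence_output": sequence_output, "pooled_output": pooled}
else:
    outputs = sequence_output
```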

"""An end-to-end Gemma3 model for embedding tasks.

This model takes token ids as input and returns a fixed-size embedding
vector for the input sequence. It uses a `Gemma3Backbone` to generate
contextualized token embeddings, a `MeanPooling` layer to pool them into a
single vector, and a final `Dense` layer to project to the desired
embedding dimension.

This model can be loaded with a pre-trained `Gemma3Backbone` and used for
tasks like semantic similarity, retrieval, or as a feature extractor.

Args:
backbone: A `keras_hub.models.Gemma3Backbone` instance.
embedding_dim: int. The dimension of the output embedding.

Example:
```python
backbone = keras_hub.models.Gemma3Backbone.from_preset(
    "gemma3_instruct_1b"
)
embedding_model = keras_hub.models.Gemma3Embedding(
    backbone=backbone,
    embedding_dim=768,
)
input_data = {
    "token_ids": np.array([[651, 4320, 8426, 25341, 235265]]),
    "padding_mask": np.ones((1, 5), dtype="int32"),
}
embeddings = embedding_model.predict(input_data)
```
"""

def __init__(self, backbone, embedding_dim, **kwargs):
super().__init__(**kwargs)
self.backbone = backbone
self.pooling_layer = MeanPooling(
dtype=backbone.dtype, name="mean_pooling"
)
self.projection_layer = layers.Dense(
embedding_dim, dtype=backbone.dtype, name="embedding_projection"
Collaborator (on lines +527 to +528): Is there just one dense layer for Embedding Gemma? I thought there were several.

)
self.embedding_dim = embedding_dim
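Following up on the reviewer's question above: if EmbeddingGemma does use more than one dense layer after pooling, a hypothetical stacked projection head could look like the sketch below; the layer count, intermediate size, and names are placeholders, not the verified architecture.

```python
# Hypothetical multi-layer projection head; sizes and names are placeholders.
self.projection_layers = [
    layers.Dense(
        intermediate_embedding_dim,
        dtype=backbone.dtype,
        name="embedding_projection_0",
    ),
    layers.Dense(
        embedding_dim,
        dtype=backbone.dtype,
        name="embedding_projection_1",
    ),
]
```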

def call(self, inputs):
Collaborator: We don't generally have call(...) for KerasHub models because these models are Functional models.
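For reference, a minimal sketch of the functional-style construction the reviewer is alluding to, where the graph is wired in __init__ and no call() override is needed; the input names mirror the existing backbone inputs, but the exact wiring here is illustrative:

```python
# Hypothetical functional wiring in __init__, instead of overriding call().
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")
sequence_output = backbone(
    {"token_ids": token_ids, "padding_mask": padding_mask}
)
pooled = MeanPooling(name="mean_pooling")((sequence_output, padding_mask))
embedding = layers.Dense(embedding_dim, name="embedding_projection")(pooled)
super().__init__(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=embedding,
    **kwargs,
)
```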

sequence_output = self.backbone(inputs)
padding_mask = inputs["padding_mask"]

pooled_output = self.pooling_layer((sequence_output, padding_mask))
embedding = self.projection_layer(pooled_output)
return embedding

def get_config(self):
config = super().get_config()
config.update(
{
"backbone": layers.serialize(self.backbone),
"embedding_dim": self.embedding_dim,
}
)
return config

@classmethod
def from_config(cls, config):
config["backbone"] = layers.deserialize(config["backbone"])
return cls(**config)
94 changes: 94 additions & 0 deletions keras_hub/src/models/gemma3/gemma3_backbone_test.py
@@ -6,6 +6,7 @@
from keras import ops

from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3EmbeddingModel
from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
Gemma3VisionEncoder,
)
@@ -193,3 +194,96 @@ def test_all_presets(self):
if "_text" in preset or "1b" in preset
else self.input_data,
)


class Gemma3EmbeddingModelTest(TestCase, parameterized.TestCase):
def setUp(self):
self.batch_size = 2
self.vocabulary_size = 256
self.text_sequence_length = 64
self.hidden_dim = 8
self.embedding_dim = 16

self.backbone = Gemma3Backbone(
vocabulary_size=self.vocabulary_size,
image_size=16,
num_layers=2,
num_query_heads=2,
num_key_value_heads=1,
hidden_dim=self.hidden_dim,
intermediate_dim=32,
head_dim=4,
vision_encoder=None,
)

self.init_kwargs = {
"backbone": self.backbone,
"embedding_dim": self.embedding_dim,
}

dummy_text_token_ids = np.random.randint(
0,
self.vocabulary_size,
(self.batch_size, self.text_sequence_length),
)
padding_mask = np.ones(
(self.batch_size, self.text_sequence_length), dtype="int32"
)
padding_mask[0, -10:] = 0
padding_mask[1, -5:] = 0

self.input_data = {
"token_ids": dummy_text_token_ids,
"padding_mask": padding_mask,
}

def test_model_basics(self):
"""Test the model's forward pass and output shape."""
model = Gemma3EmbeddingModel(**self.init_kwargs)
output = model(self.input_data)
expected_output_shape = (self.batch_size, self.embedding_dim)
self.assertEqual(output.shape, expected_output_shape)

def test_architecture_characteristics(self):
"""Test parameter and layer counts."""
model = Gemma3EmbeddingModel(**self.init_kwargs)

model(self.input_data)

backbone_params = self.backbone.count_params()
projection_params = (
self.hidden_dim * self.embedding_dim
) + self.embedding_dim
expected_params = backbone_params + projection_params

expected_layers = 3

self.assertEqual(model.count_params(), expected_params)
self.assertEqual(len(model.layers), expected_layers)

def test_saved_model(self):
self.run_model_saving_test(
cls=Gemma3EmbeddingModel,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_build_from_preset_backbone(self):
backbone = Gemma3Backbone.from_preset("gemma3_instruct_1b_text")
model = Gemma3EmbeddingModel(
backbone=backbone,
embedding_dim=768,
)

input_data = {
"token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]),
"padding_mask": ops.ones((1, 5), dtype="int32"),
}

outputs = model(input_data)

self.assertEqual(outputs.shape, (1, 768))
norm = ops.sqrt(ops.sum(outputs * outputs, axis=1))
self.assertGreater(norm[0], 0)