79 changes: 78 additions & 1 deletion keras_hub/src/models/gemma3/gemma3_backbone.py
@@ -1,5 +1,5 @@
import keras
from keras import ops
from keras import ops, layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
@@ -424,3 +424,80 @@ def from_config(cls, config):
)

return super().from_config(config)


class MeanPooling(layers.Layer):
Review comment (Collaborator): Let's move this to gemma3_mean_pooling.py?

"""
This layer calculates the mean of the token embeddings, ignoring
padded tokens.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True

def call(self, inputs):
Review comment (Collaborator): Can we pass these arguments to call(...) separately instead of as a tuple?

sequence_output, padding_mask = inputs

mask = ops.expand_dims(
ops.cast(padding_mask, sequence_output.dtype), axis=-1
)
masked_output = sequence_output * mask
sum_embeddings = ops.sum(masked_output, axis=1)

num_tokens = ops.sum(ops.cast(padding_mask, sequence_output.dtype), axis=1)
num_tokens = ops.expand_dims(num_tokens, axis=1)

num_tokens = ops.maximum(num_tokens, 1e-9)

mean_embeddings = sum_embeddings / num_tokens
return mean_embeddings

def compute_output_shape(self, input_shape):
sequence_output_shape, padding_mask_shape = input_shape
return (sequence_output_shape[0], sequence_output_shape[2])
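
A minimal sketch of the reviewer's suggestion above, with call(...) taking sequence_output and padding_mask as separate arguments instead of a tuple. This is an illustration under that assumption, not part of the diff:

from keras import layers, ops


class MeanPooling(layers.Layer):
    """Mean of token embeddings over the sequence axis, ignoring padding."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, sequence_output, padding_mask):
        # Zero out embeddings at padded positions before summing.
        mask = ops.expand_dims(
            ops.cast(padding_mask, sequence_output.dtype), axis=-1
        )
        sum_embeddings = ops.sum(sequence_output * mask, axis=1)
        # Count real tokens per sequence and clamp to avoid division by zero.
        num_tokens = ops.sum(
            ops.cast(padding_mask, sequence_output.dtype),
            axis=1,
            keepdims=True,
        )
        num_tokens = ops.maximum(num_tokens, 1e-9)
        return sum_embeddings / num_tokens

    def compute_output_shape(self, sequence_output_shape, padding_mask_shape):
        return (sequence_output_shape[0], sequence_output_shape[2])

The call site in the embedding model would then read pooled_output = self.pooling_layer(sequence_output, padding_mask).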


@keras_hub_export("keras_hub.models.Gemma3Embedding")
class Gemma3EmbeddingModel(keras.Model):
Review comment (Collaborator): I think we should just stick the layers on top in Gemma3Backbone instead of creating a separate class for this. We could probably add an argument named is_embedding_model; if it is True, add the dense layers and the mean pooling layer, and return a dictionary containing the hidden states and the pooled output. What do you think?


def __init__(
self,
backbone: Gemma3Backbone,
embedding_dim: int,
**kwargs,
):
self.backbone = backbone
self.pooling_layer = MeanPooling(
dtype=backbone.dtype, name="mean_pooling"
)
self.projection_layer = layers.Dense(
embedding_dim, dtype=backbone.dtype, name="embedding_projection"
Review comment on lines +527 to +528 (Collaborator): Is there just one dense layer for Embedding Gemma? I thought there were several.

)

inputs = self.backbone.input
sequence_output = self.backbone.outputs[0]
padding_mask = inputs["padding_mask"]

pooled_output = self.pooling_layer([sequence_output, padding_mask])

embedding = self.projection_layer(pooled_output)

super().__init__(inputs=inputs, outputs=embedding, **kwargs)

self.embedding_dim = embedding_dim

def get_config(self):
config = super().get_config()
config.update(
{
"backbone": keras.layers.serialize(self.backbone),
"embedding_dim": self.embedding_dim,
}
)
return config

@classmethod
def from_config(cls, config):
config["backbone"] = keras.layers.deserialize(config["backbone"])
return cls(**config)
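
A rough sketch of the reviewer's alternative above, in which the pooling and projection layers are stacked on top of Gemma3Backbone (for example behind an is_embedding_model flag) and the model returns a dictionary with the hidden states and the pooled embedding. The helper name add_embedding_head, the output keys, and the single Dense projection are assumptions for illustration, not part of this PR:

import keras
from keras import layers


def add_embedding_head(backbone, embedding_dim):
    # Illustrative only: roughly what Gemma3Backbone could wire up internally
    # when an is_embedding_model-style flag is set.
    inputs = backbone.input
    sequence_output = backbone.outputs[0]
    padding_mask = inputs["padding_mask"]

    pooled_output = MeanPooling(dtype=backbone.dtype, name="mean_pooling")(
        [sequence_output, padding_mask]
    )
    embedding = layers.Dense(
        embedding_dim, dtype=backbone.dtype, name="embedding_projection"
    )(pooled_output)

    return keras.Model(
        inputs=inputs,
        outputs={
            "sequence_output": sequence_output,
            "pooled_output": embedding,
        },
    )

Returning a dictionary keeps the per-token hidden states available next to the pooled embedding, which is what the comment asks for.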
93 changes: 93 additions & 0 deletions keras_hub/src/models/gemma3/gemma3_backbone_test.py
@@ -9,6 +9,9 @@
from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
Gemma3VisionEncoder,
)
from keras_hub.src.models.gemma3.gemma3_backbone import (
Gemma3EmbeddingModel,
)
from keras_hub.src.tests.test_case import TestCase


@@ -193,3 +196,93 @@ def test_all_presets(self):
if "_text" in preset or "1b" in preset
else self.input_data,
)

class Gemma3EmbeddingModelTest(TestCase, parameterized.TestCase):
def setUp(self):
self.batch_size = 2
self.vocabulary_size = 256
self.text_sequence_length = 64
self.hidden_dim = 8
self.embedding_dim = 16

self.backbone = Gemma3Backbone(
vocabulary_size=self.vocabulary_size,
image_size=16,
num_layers=2,
num_query_heads=2,
num_key_value_heads=1,
hidden_dim=self.hidden_dim,
intermediate_dim=32,
head_dim=4,
vision_encoder=None,
)

self.init_kwargs = {
"backbone": self.backbone,
"embedding_dim": self.embedding_dim,
}

dummy_text_token_ids = np.random.randint(
0,
self.vocabulary_size,
(self.batch_size, self.text_sequence_length),
)
padding_mask = np.ones(
(self.batch_size, self.text_sequence_length), dtype="int32"
)
padding_mask[0, -10:] = 0
padding_mask[1, -5:] = 0

self.input_data = {
"token_ids": dummy_text_token_ids,
"padding_mask": padding_mask,
}

def test_model_basics(self):
"""Test the model's forward pass and output shape."""
model = Gemma3EmbeddingModel(**self.init_kwargs)
output = model(self.input_data)
expected_output_shape = (self.batch_size, self.embedding_dim)
self.assertEqual(output.shape, expected_output_shape)

def test_architecture_characteristics(self):
"""Test parameter and layer counts."""
model = Gemma3EmbeddingModel(**self.init_kwargs)

backbone_params = self.backbone.count_params()
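        # Dense projection parameters: a kernel of shape
        # (hidden_dim, embedding_dim) plus an embedding_dim-sized bias.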
projection_params = (
self.hidden_dim * self.embedding_dim
) + self.embedding_dim
expected_params = backbone_params + projection_params

expected_layers = 8

self.assertEqual(model.count_params(), expected_params)
self.assertEqual(len(model.layers), expected_layers)

def test_saved_model(self):
self.run_model_saving_test(
cls=Gemma3EmbeddingModel,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_build_from_preset_backbone(self):
backbone = Gemma3Backbone.from_preset("gemma3_instruct_1b_text")
model = Gemma3EmbeddingModel(
backbone=backbone,
embedding_dim=768,
)

input_data = {
"token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]),
"padding_mask": ops.ones((1, 5), dtype="int32"),
}

outputs = model(input_data)

self.assertEqual(outputs.shape, (1, 768))
norm = ops.vector_norm(outputs, axis=1)
self.assertGreater(norm[0], 0)