-
Notifications
You must be signed in to change notification settings - Fork 301
Changes to the Gemma3 backbone for Embedding Gemma model #2428
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import keras | ||
from keras import ops | ||
from keras import ops, layers | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.modeling.reversible_embedding import ( | ||
|
@@ -424,3 +424,80 @@ def from_config(cls, config): | |
) | ||
|
||
return super().from_config(config) | ||
|
||
|
||
class MeanPooling(layers.Layer): | ||
""" | ||
This layer calculates the mean of the token embeddings, ignoring | ||
padded tokens. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self.supports_masking = True | ||
|
||
def call(self, inputs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of a tuple, can we just pass these arguments to |
||
sequence_output, padding_mask = inputs | ||
|
||
mask = ops.expand_dims( | ||
ops.cast(padding_mask, sequence_output.dtype), axis=-1 | ||
) | ||
masked_output = sequence_output * mask | ||
sum_embeddings = ops.sum(masked_output, axis=1) | ||
|
||
num_tokens = ops.sum(ops.cast(padding_mask, sequence_output.dtype), axis=1) | ||
num_tokens = ops.expand_dims(num_tokens, axis=1) | ||
|
||
num_tokens = ops.maximum(num_tokens, 1e-9) | ||
|
||
mean_embeddings = sum_embeddings / num_tokens | ||
return mean_embeddings | ||
|
||
def compute_output_shape(self, input_shape): | ||
sequence_output_shape, padding_mask_shape = input_shape | ||
return (sequence_output_shape[0], sequence_output_shape[2]) | ||
|
||
buildwithsuhana marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@keras_hub_export("keras_hub.models.Gemma3Embedding") | ||
class Gemma3EmbeddingModel(keras.Model): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should just stick the layers on top in Can probably add an argument, named What do you think? |
||
|
||
def __init__( | ||
self, | ||
backbone: Gemma3Backbone, | ||
embedding_dim: int, | ||
**kwargs, | ||
): | ||
self.backbone = backbone | ||
self.pooling_layer = MeanPooling( | ||
dtype=backbone.dtype, name="mean_pooling" | ||
) | ||
self.projection_layer = layers.Dense( | ||
embedding_dim, dtype=backbone.dtype, name="embedding_projection" | ||
Comment on lines
+527
to
+528
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there just one dense layer for Embedding Gemma? I thought there were several. |
||
) | ||
|
||
inputs = self.backbone.input | ||
sequence_output = self.backbone.outputs[0] | ||
padding_mask = inputs["padding_mask"] | ||
|
||
pooled_output = self.pooling_layer([sequence_output, padding_mask]) | ||
|
||
embedding = self.projection_layer(pooled_output) | ||
|
||
super().__init__(inputs=inputs, outputs=embedding, **kwargs) | ||
|
||
self.embedding_dim = embedding_dim | ||
|
||
buildwithsuhana marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"backbone": keras.layers.serialize(self.backbone), | ||
"embedding_dim": self.embedding_dim, | ||
} | ||
) | ||
return config | ||
|
||
@classmethod | ||
def from_config(cls, config): | ||
config["backbone"] = keras.layers.deserialize(config["backbone"]) | ||
return cls(**config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move this to
gemma3_mean_pooling.py
?