Commit 5f82baf

Update EmbeddingGemma to support optional final L2 normalization
This commit introduces an optional `final_l2_norm` parameter to the EmbeddingGemma model, allowing users to enable or disable the final L2 normalization step. This enhancement is particularly useful for advanced use cases such as Matryoshka embeddings.

Changes:

- Model Update: Modified `EmbeddingGemma` in `embedding_gemma.py` to accept a `final_l2_norm` argument in `__init__` and `build_model`. The `forward` method now conditionally applies L2 normalization.
- Conversion Script: Updated `convert_to_tflite.py` to include a `--final_l2_norm` flag (defaulting to `True`), passing it to the model builder.
- Verification: Updated `verify.py` and `verify_util.py` to support the `final_l2_norm` flag, ensuring the verification process matches the model configuration.
- Documentation: Updated `README.md` to document the new flag, mention Matryoshka embeddings, and include `litert-torch` in the requirements.

The default behavior remains unchanged (`final_l2_norm=True`), ensuring backward compatibility.
1 parent 4c84cc3 commit 5f82baf

5 files changed

Lines changed: 88 additions & 17 deletions

litert_torch/generative/examples/embedding_gemma/README.md

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# Embedding Gemma-300M

[Embedding Gemma](https://huggingface.co/google/embeddinggemma-300m) is a text embedding model based on the Gemma architecture. This example demonstrates how to reauthor the model using the LiteRT Torch Generative API and convert it to TFLite.

## Model Details

EmbeddingGemma-300M is an encoder-only model with 24 layers. It uses a combination of local sliding window attention and global attention. This implementation supports:

- Sliding window attention mask.
- Mean pooling of hidden states.
- Final dense projections and L2 normalization (optional, useful for Matryoshka embeddings).

## Requirements

To run this example and verify the results, you need the following packages:

```bash
pip install litert-torch transformers sentence-transformers safetensors
```
## Convert to TFLite

To convert the model to TFLite, use the `convert_to_tflite.py` script. You'll need the HuggingFace checkpoint for `google/embeddinggemma-300m`.

```bash
python convert_to_tflite.py \
  --checkpoint_path=<path_to_checkpoint> \
  --output_path=/tmp/ \
  --quantize=dynamic_int8 \
  --prefill_seq_lens=512 \
  --final_l2_norm=True
```
### Conversion Flags

- `--checkpoint_path`: Path to the directory containing the model's `model.safetensors` and Dense projection layers.
- `--output_path`: Directory where the converted `.tflite` model will be saved.
- `--quantize`: Quantization scheme (e.g., `dynamic_int8`, `none`).
- `--prefill_seq_lens`: Defines the input sequence length for the converted TFLite model.
- `--final_l2_norm`: Whether to apply final L2 normalization to the embeddings. Defaults to `True`. Set to `False` if using Matryoshka embeddings.
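When `--final_l2_norm=False`, the exported model returns unnormalized embeddings, which lets a downstream consumer truncate them to a smaller Matryoshka dimension and normalize only afterwards. A minimal plain-Python sketch of that post-processing step (the helper names here are illustrative, not part of this example):

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit L2 norm (no-op for the zero vector)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def matryoshka_truncate(embedding, dim):
    """Keep the first `dim` components, then re-normalize.

    Normalizing only after truncation is the reason the model should
    skip its own final L2 norm when Matryoshka embeddings are wanted.
    """
    return l2_normalize(embedding[:dim])

full = [0.3, -1.2, 0.5, 2.0, -0.7, 0.1]  # unnormalized model output
small = matryoshka_truncate(full, 4)     # 4-dim Matryoshka embedding
```

If the model had already applied its own L2 norm over all dimensions, truncating afterwards would leave a vector whose norm is less than 1, which is why the flag disables normalization inside the model rather than leaving it to the caller to undo.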
## Verify the Model

You can verify the reauthored model's output against the original HuggingFace model using `verify.py`.

```bash
python verify.py \
  --checkpoint=<path_to_checkpoint> \
  --prompts="What is the meaning of life?" \
  --prompts="This is an example sentence."
```
The verification script compares the final embeddings produced by the original `sentence-transformers` implementation and the reauthored `litert_torch` implementation to ensure parity.
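Parity between two embedding vectors is commonly summarized by their cosine similarity, which should be very close to 1.0 when the reauthoring is faithful. A toy plain-Python illustration of that idea (the function name is ours, not the script's):

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

# Identical embeddings give a similarity of exactly 1.0;
# orthogonal embeddings give 0.0.
sim = cosine_similarity([0.6, 0.8], [0.6, 0.8])
```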

litert_torch/generative/examples/embedding_gemma/convert_to_tflite.py

Lines changed: 8 additions & 1 deletion
@@ -25,11 +25,18 @@
 flags = generative_converter.define_conversion_flags(
     model_name="embedding_gemma"
 )
+flags.DEFINE_bool(
+    "final_l2_norm",
+    True,
+    "Whether to apply final L2 normalization to the embeddings.",
+)
 FLAGS = flags.FLAGS


 def main(_):
-  model = embedding_gemma.build_model(FLAGS.checkpoint_path)
+  model = embedding_gemma.build_model(
+      FLAGS.checkpoint_path, final_l2_norm=FLAGS.final_l2_norm
+  )
   model.eval()
   seq_len = max(FLAGS.prefill_seq_lens)

litert_torch/generative/examples/embedding_gemma/embedding_gemma.py

Lines changed: 13 additions & 13 deletions
@@ -91,9 +91,10 @@ def forward(
 class EmbeddingGemma(nn.Module):
   """EmbeddingGemma-300M model."""

-  def __init__(self, config: cfg.ModelConfig):
+  def __init__(self, config: cfg.ModelConfig, final_l2_norm: bool = True):
     super().__init__()
     self.config = config
+    self.final_l2_norm = final_l2_norm

     # Token embeddings
     self.embedder = nn.Embedding(
@@ -175,12 +176,10 @@ def create_sliding_mask(
   def mean_pool(self, hidden_states, attention_mask):
     """Mean pooling with attention mask."""
     if attention_mask is not None:
-      input_mask_expanded = attention_mask.unsqueeze(-1).expand(
-          hidden_states.size()
-      ).float()
-      sum_embeddings = torch.sum(
-          hidden_states * input_mask_expanded, dim=1
+      input_mask_expanded = (
+          attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
       )
+      sum_embeddings = torch.sum(hidden_states * input_mask_expanded, dim=1)
       sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
       return sum_embeddings / sum_mask
     else:
@@ -196,9 +195,7 @@ def forward(
     batch_size, seq_len = tokens.shape

     if attention_mask is None:
-      attention_mask = torch.ones(
-          batch_size, seq_len, device=tokens.device
-      )
+      attention_mask = torch.ones(batch_size, seq_len, device=tokens.device)

     x = self.embedder(tokens) * self.config.embedding_scale

@@ -263,8 +260,11 @@ def forward(
     pooled_x = self.dense2(pooled_x)

     # L2 normalization
-    embedding = torch.nn.functional.normalize(pooled_x, p=2, dim=1)
-    return embedding
+    if self.final_l2_norm:
+      embedding = torch.nn.functional.normalize(pooled_x, p=2, dim=1)
+      return embedding
+
+    return pooled_x


 def get_model_config() -> cfg.ModelConfig:
@@ -349,11 +349,11 @@ def get_model_config() -> cfg.ModelConfig:
   return config


-def build_model(checkpoint_path) -> EmbeddingGemma:
+def build_model(checkpoint_path, final_l2_norm: bool = True) -> EmbeddingGemma:
   """Build model and load weights from HuggingFace checkpoint."""

   config = get_model_config()
-  model = EmbeddingGemma(config)
+  model = EmbeddingGemma(config, final_l2_norm)

   print(f"Loading from checkpoint: {checkpoint_path}")
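The pooling-and-normalize tail of `forward` above can be sketched in plain Python (lists instead of tensors, hypothetical names) to show what the new `final_l2_norm` switch changes: masked mean pooling always runs, while L2 normalization is applied only when the flag is set.

```python
import math

def masked_mean_pool(hidden_states, attention_mask):
    """Average per-token vectors, counting only positions where the mask is 1.

    hidden_states: list of per-token vectors; attention_mask: list of 0/1.
    """
    dim = len(hidden_states[0])
    sums = [0.0] * dim
    count = 0
    for vec, m in zip(hidden_states, attention_mask):
        if m:
            count += 1
            for i, x in enumerate(vec):
                sums[i] += x
    count = max(count, 1)  # mirrors the clamp(min=1e-9) divide-by-zero guard
    return [s / count for s in sums]

def embed(hidden_states, attention_mask, final_l2_norm=True):
    pooled = masked_mean_pool(hidden_states, attention_mask)
    if not final_l2_norm:
        return pooled  # raw embedding, e.g. for Matryoshka use
    norm = math.sqrt(sum(x * x for x in pooled)) or 1.0
    return [x / norm for x in pooled]

tokens = [[1.0, 0.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]  # last position is padding and is excluded from the mean
raw = embed(tokens, mask, final_l2_norm=False)
unit = embed(tokens, mask)
```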

litert_torch/generative/examples/embedding_gemma/verify.py

Lines changed: 6 additions & 0 deletions
@@ -35,13 +35,19 @@
     None,
     "Whether to enable the long input test.",
 )
+_FINAL_L2_NORM = flags.DEFINE_bool(
+    "final_l2_norm",
+    True,
+    "Whether to apply final L2 normalization to the embeddings.",
+)


 def main(_):
   if not verify_util.verify_embedding_gemma_300m(
       checkpoint_dir=_CHECKPOINT.value,
       prompts=_PROMPTS.value,
       long_input_prompt_path=_LONG_INPUT_PROMPT_PATH.value,
+      final_l2_norm=_FINAL_L2_NORM.value,
   ):
     exit(1)

litert_torch/generative/examples/embedding_gemma/verify_util.py

Lines changed: 9 additions & 3 deletions
@@ -24,17 +24,19 @@

 DEFAULT_PROMPTS = [
     "What is the meaning of life?",
-    "This is an example sentence."
+    "This is an example sentence.",
 ]

 _MODEL_PATH = "google/embeddinggemma-300m"
 _LONG_INPUT_PROMPT_PATH = "long_input_prompt_test.txt"

+
 def verify_embedding_gemma_300m(
     checkpoint_dir: str = None,
     prompts: list[str] | None = None,
     long_input_prompt_path: str = None,
     atol: float = 0.0,
+    final_l2_norm: bool = True,
 ) -> bool:
   """Verifies EmbeddingGemma-300M."""
@@ -50,7 +52,9 @@ def verify_embedding_gemma_300m(

   print(f"Loading reauthored model from: {checkpoint_dir}")
   try:
-    reauthored_model = embedding_gemma.build_model(checkpoint_dir)
+    reauthored_model = embedding_gemma.build_model(
+        checkpoint_dir, final_l2_norm=final_l2_norm
+    )
     reauthored_model.eval()
   except Exception as e:  # pylint: disable=broad-except
     print(f"Failed to build or load reauthored model: {e}")
@@ -88,7 +92,9 @@ def verify_embedding_gemma_300m(
   print("\n--- Comparing Final Embeddings ---")
   with torch.no_grad():
     # Get embeddings from the original SentenceTransformer model.
-    final_original_output = original_model(inputs)  # pytype: disable=wrong-arg-types
+    final_original_output = original_model(
+        inputs
+    )  # pytype: disable=wrong-arg-types
     original_embedding = final_original_output["sentence_embedding"]

     # Get embeddings from the reauthored model.
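The `atol` parameter above bounds the allowed element-wise deviation between the two models' embeddings, with the default of `0.0` demanding exact agreement. A plain-Python sketch of that style of tolerance check (a hypothetical helper, not the comparison function the script actually calls):

```python
def within_atol(a, b, atol=0.0):
    """True if every element of `a` is within `atol` of the matching element of `b`."""
    return len(a) == len(b) and all(abs(x - y) <= atol for x, y in zip(a, b))

exact = within_atol([0.25, -0.5], [0.25, -0.5])            # atol=0.0: exact match
loose = within_atol([0.25, -0.5], [0.251, -0.5], atol=1e-2)  # small drift allowed
```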

0 commit comments
