Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions litert_torch/generative/examples/embedding_gemma/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Embedding Gemma-300M

[Embedding Gemma](https://huggingface.co/google/embeddinggemma-300m) is a text embedding model based on the Gemma architecture. This example demonstrates how to reauthor the model using the LiteRT Torch Generative API and convert it to TFLite.

## Model Details

EmbeddingGemma-300M is an encoder-only model with 24 layers. It uses a combination of local sliding window attention and global attention. This implementation supports:
- Sliding window attention mask.
- Mean pooling of hidden states.
- Final dense projections and L2 normalization (optional, useful for Matryoshka embeddings).

## Requirements

To run this example and verify the results, you need the following packages:

```bash
pip install litert-torch transformers sentence-transformers safetensors
```

## Convert to TFLite

To convert the model to TFLite, use the `convert_to_tflite.py` script. You'll need the HuggingFace checkpoint for `google/embeddinggemma-300m`.

```bash
python convert_to_tflite.py \
  --checkpoint_path=<path_to_checkpoint> \
  --output_path=/tmp/ \
  --quantize=dynamic_int8 \
  --prefill_seq_lens=512 \
  --final_l2_norm=True
```

### Conversion Flags

- `--checkpoint_path`: Path to the directory containing the model's `model.safetensors` and Dense projection layers.
- `--output_path`: Directory where the converted `.tflite` model will be saved.
- `--quantize`: Quantization scheme (e.g., `dynamic_int8`, `none`).
- `--prefill_seq_lens`: Defines the input sequence length for the converted TFLite model.
- `--final_l2_norm`: Whether to apply final L2 normalization to the embeddings. Defaults to `True`. Set to `False` if using Matryoshka embeddings.

## Verify the Model

You can verify the reauthored model's output against the original HuggingFace model using `verify.py`.

```bash
python verify.py \
  --checkpoint=<path_to_checkpoint> \
  --prompts="What is the meaning of life?" \
  --prompts="This is an example sentence."
```

The verification script compares the final embeddings produced by the original `sentence-transformers` implementation and the reauthored `litert_torch` implementation to ensure parity.
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@
flags = generative_converter.define_conversion_flags(
model_name="embedding_gemma"
)
flags.DEFINE_bool(
"final_l2_norm",
True,
"Whether to apply final L2 normalization to the embeddings.",
)
FLAGS = flags.FLAGS


def main(_):
model = embedding_gemma.build_model(FLAGS.checkpoint_path)
model = embedding_gemma.build_model(
FLAGS.checkpoint_path, final_l2_norm=FLAGS.final_l2_norm
)
model.eval()
seq_len = max(FLAGS.prefill_seq_lens)

Expand Down
26 changes: 13 additions & 13 deletions litert_torch/generative/examples/embedding_gemma/embedding_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ def forward(
class EmbeddingGemma(nn.Module):
"""EmbeddingGemma-300M model."""

def __init__(self, config: cfg.ModelConfig):
def __init__(self, config: cfg.ModelConfig, final_l2_norm: bool = True):
super().__init__()
self.config = config
self.final_l2_norm = final_l2_norm

# Token embeddings
self.embedder = nn.Embedding(
Expand Down Expand Up @@ -175,12 +176,10 @@ def create_sliding_mask(
def mean_pool(self, hidden_states, attention_mask):
"""Mean pooling with attention mask."""
if attention_mask is not None:
input_mask_expanded = attention_mask.unsqueeze(-1).expand(
hidden_states.size()
).float()
sum_embeddings = torch.sum(
hidden_states * input_mask_expanded, dim=1
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
)
sum_embeddings = torch.sum(hidden_states * input_mask_expanded, dim=1)
sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
return sum_embeddings / sum_mask
else:
Expand All @@ -196,9 +195,7 @@ def forward(
batch_size, seq_len = tokens.shape

if attention_mask is None:
attention_mask = torch.ones(
batch_size, seq_len, device=tokens.device
)
attention_mask = torch.ones(batch_size, seq_len, device=tokens.device)

x = self.embedder(tokens) * self.config.embedding_scale

Expand Down Expand Up @@ -263,8 +260,11 @@ def forward(
pooled_x = self.dense2(pooled_x)

# L2 normalization
embedding = torch.nn.functional.normalize(pooled_x, p=2, dim=1)
return embedding
if self.final_l2_norm:
embedding = torch.nn.functional.normalize(pooled_x, p=2, dim=1)
return embedding

return pooled_x


def get_model_config() -> cfg.ModelConfig:
Expand Down Expand Up @@ -349,11 +349,11 @@ def get_model_config() -> cfg.ModelConfig:
return config


def build_model(checkpoint_path) -> EmbeddingGemma:
def build_model(checkpoint_path, final_l2_norm: bool = True) -> EmbeddingGemma:
"""Build model and load weights from HuggingFace checkpoint."""

config = get_model_config()
model = EmbeddingGemma(config)
model = EmbeddingGemma(config, final_l2_norm)

print(f"Loading from checkpoint: {checkpoint_path}")

Expand Down
6 changes: 6 additions & 0 deletions litert_torch/generative/examples/embedding_gemma/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,19 @@
None,
"Whether to enable the long input test.",
)
_FINAL_L2_NORM = flags.DEFINE_bool(
"final_l2_norm",
True,
"Whether to apply final L2 normalization to the embeddings.",
)


def main(_):
if not verify_util.verify_embedding_gemma_300m(
checkpoint_dir=_CHECKPOINT.value,
prompts=_PROMPTS.value,
long_input_prompt_path=_LONG_INPUT_PROMPT_PATH.value,
final_l2_norm=_FINAL_L2_NORM.value,
):
exit(1)

Expand Down
12 changes: 9 additions & 3 deletions litert_torch/generative/examples/embedding_gemma/verify_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@

DEFAULT_PROMPTS = [
"What is the meaning of life?",
"This is an example sentence."
"This is an example sentence.",
]

_MODEL_PATH = "google/embeddinggemma-300m"
_LONG_INPUT_PROMPT_PATH = "long_input_prompt_test.txt"


def verify_embedding_gemma_300m(
checkpoint_dir: str = None,
prompts: list[str] | None = None,
long_input_prompt_path: str = None,
atol: float = 0.0,
final_l2_norm: bool = True,
) -> bool:
"""Verifies EmbeddingGemma-300M."""

Expand All @@ -50,7 +52,9 @@ def verify_embedding_gemma_300m(

print(f"Loading reauthored model from: {checkpoint_dir}")
try:
reauthored_model = embedding_gemma.build_model(checkpoint_dir)
reauthored_model = embedding_gemma.build_model(
checkpoint_dir, final_l2_norm=final_l2_norm
)
reauthored_model.eval()
except Exception as e: # pylint: disable=broad-except
print(f"Failed to build or load reauthored model: {e}")
Expand Down Expand Up @@ -88,7 +92,9 @@ def verify_embedding_gemma_300m(
print("\n--- Comparing Final Embeddings ---")
with torch.no_grad():
# Get embeddings from the original SentenceTransformer model.
final_original_output = original_model(inputs) # pytype: disable=wrong-arg-types
final_original_output = original_model(
inputs
) # pytype: disable=wrong-arg-types
original_embedding = final_original_output["sentence_embedding"]

# Get embeddings from the reauthored model.
Expand Down