
Commit 5fb930e

ai-edge-bot authored and copybara-github committed

Fix an edge case when eos_token_id is not defined in tokenizer.
Also, fix a wrong reference link of paligemma2.

PiperOrigin-RevId: 712559701
1 parent 58a7cde commit 5fb930e

File tree

2 files changed: +5 −5 lines changed


ai_edge_torch/generative/examples/README.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ Gemma is Google's open-source LLM. The model has both a 2B and 7B versions. See
 ## PaliGemma
 PaliGemma is a multimodal LLM which gets images and text as input, then
 generates text as output. See
-[model's Kaggle page](https://www.kaggle.com/models/google/paligemma2).
+[model's Kaggle page](https://www.kaggle.com/models/google/paligemma-2).
 The examples we provide are PaliGemma2 and 1 of 3B with 224 image size.
 The checkpoint for PaliGemma2 can be downloaded from
 [here](https://www.kaggle.com/models/google/paligemma-2/transformers/paligemma2-3b-pt-224).

ai_edge_torch/generative/utilities/verifier.py

Lines changed: 4 additions & 4 deletions
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""
 
 import logging
-from typing import Any, List
+from typing import Any, List, Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -134,7 +134,7 @@ def generate(
     prompts: torch.Tensor,
     max_new_tokens: int,
     pixel_values: torch.Tensor = None,
-    eos_token_id: int = 1,
+    eos_token_id: Optional[int] = None,
 ) -> torch.IntTensor:
   input_ids = prompts[0].int().tolist()
   tokens = torch.tensor([input_ids])
@@ -146,7 +146,7 @@ def generate(
     )
     generated_token = logits[0][-1].argmax().item()
     input_ids.append(generated_token)
-    if generated_token == eos_token_id:
+    if eos_token_id is not None and generated_token == eos_token_id:
       break
     tokens = torch.tensor([[generated_token]])
     input_pos = torch.tensor([len(input_ids) - 1])
@@ -253,7 +253,7 @@ def verify_model_with_prompts(
   outputs_reauthored = reauthored_model.generate(
       prompt_tokens,
      max_new_tokens,
-      eos_token_id=tokenizer.tokenizer.eos_token_id,
+      eos_token_id=getattr(tokenizer.tokenizer, "eos_token_id", None),
   )
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
