1616"""Common utility functions to verify the reauthored models."""
1717
1818import logging
19- from typing import Any ,List
19+ from typing import Any , List , Optional
2020
2121from ai_edge_torch .generative .layers import kv_cache as kv_utils
2222from ai_edge_torch .generative .utilities .model_builder import ExportConfig
@@ -134,7 +134,7 @@ def generate(
134134 prompts : torch .Tensor ,
135135 max_new_tokens : int ,
136136 pixel_values : torch .Tensor = None ,
137- eos_token_id : int = 1 ,
137+ eos_token_id : Optional [ int ] = None ,
138138 ) -> torch .IntTensor :
139139 input_ids = prompts [0 ].int ().tolist ()
140140 tokens = torch .tensor ([input_ids ])
@@ -146,7 +146,7 @@ def generate(
146146 )
147147 generated_token = logits [0 ][- 1 ].argmax ().item ()
148148 input_ids .append (generated_token )
149- if generated_token == eos_token_id :
149+ if eos_token_id is not None and generated_token == eos_token_id :
150150 break
151151 tokens = torch .tensor ([[generated_token ]])
152152 input_pos = torch .tensor ([len (input_ids ) - 1 ])
@@ -253,7 +253,7 @@ def verify_model_with_prompts(
253253 outputs_reauthored = reauthored_model .generate (
254254 prompt_tokens ,
255255 max_new_tokens ,
256- eos_token_id = tokenizer .tokenizer . eos_token_id ,
256+ eos_token_id = getattr ( tokenizer .tokenizer , " eos_token_id" , None ) ,
257257 )
258258 response_reauthored = tokenizer .decode (outputs_reauthored [0 ])
259259 logging .info ("outputs from reauthored model: [[%s]]" , response_reauthored )
0 commit comments