Skip to content

Commit dcf507e

Browse files
authored
Update gen_model_answer.py
Signed-off-by: Chenjie Luo <[email protected]>
1 parent 8845aba commit dcf507e

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

examples/llm_eval/gen_model_answer.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,18 @@ def get_model_answers(
178178
nim_model=None,
179179
):
180180
# Model Optimizer modification
181+
tokenizer = get_tokenizer(model_path, trust_remote_code=args.trust_remote_code)
181182
if checkpoint_dir:
182-
tokenizer = get_tokenizer(model_path, trust_remote_code=args.trust_remote_code)
183-
if checkpoint_dir:
184-
# get model type
185-
last_part = os.path.basename(checkpoint_dir)
186-
model_type = last_part.split("_")[0]
187-
# Some models require to set pad_token and eos_token based on external config (e.g., qwen)
188-
if model_type == "qwen":
189-
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
190-
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
191-
192-
assert LLM is not None, "tensorrt_llm APIs could not be imported."
193-
model = LLM(checkpoint_dir, tokenizer=tokenizer)
194-
else:
195-
raise ValueError("checkpoint_dir is required for TensorRT LLM inference.")
183+
# get model type
184+
last_part = os.path.basename(checkpoint_dir)
185+
model_type = last_part.split("_")[0]
186+
# Some models require to set pad_token and eos_token based on external config (e.g., qwen)
187+
if model_type == "qwen":
188+
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
189+
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
190+
191+
assert LLM is not None, "tensorrt_llm APIs could not be imported."
192+
model = LLM(checkpoint_dir, tokenizer=tokenizer)
196193
elif not nim_model:
197194
model, _ = load_model(
198195
model_path,
@@ -205,7 +202,6 @@ def get_model_answers(
205202
cpu_offloading=False,
206203
debug=False,
207204
)
208-
tokenizer = get_tokenizer(model_path, trust_remote_code=args.trust_remote_code)
209205
if args.quant_cfg:
210206
quantize_model(
211207
model,

0 commit comments

Comments
 (0)