Skip to content

Commit 13ffe52

Browse files
authored
[None][fix] Allow YAML config overwriting CLI args for trtllm-eval (#10296)
Signed-off-by: Enwei Zhu <[email protected]>
1 parent f3f0231 commit 13ffe52

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

tensorrt_llm/commands/eval.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -143,27 +143,30 @@ def main(ctx, model: str, tokenizer: Optional[str],
143143
"kv_cache_config": kv_cache_config,
144144
}
145145

146-
if extra_llm_api_options is not None:
147-
llm_args = update_llm_args_with_extra_options(llm_args,
148-
extra_llm_api_options)
149-
150-
profiler.start("trtllm init")
151146
if backend == 'pytorch':
152-
llm = PyTorchLLM(**llm_args,
153-
max_batch_size=max_batch_size,
154-
max_num_tokens=max_num_tokens,
155-
max_beam_width=max_beam_width,
156-
max_seq_len=max_seq_len)
147+
llm_cls = PyTorchLLM
148+
llm_args.update(max_batch_size=max_batch_size,
149+
max_num_tokens=max_num_tokens,
150+
max_beam_width=max_beam_width,
151+
max_seq_len=max_seq_len)
157152
elif backend == 'tensorrt':
153+
llm_cls = LLM
158154
build_config = BuildConfig(max_batch_size=max_batch_size,
159155
max_num_tokens=max_num_tokens,
160156
max_beam_width=max_beam_width,
161157
max_seq_len=max_seq_len)
162-
llm = LLM(**llm_args, build_config=build_config)
158+
llm_args.update(build_config=build_config)
163159
else:
164160
raise click.BadParameter(
165161
f"{backend} is not a known backend, check help for available options.",
166162
param_hint="backend")
163+
164+
if extra_llm_api_options is not None:
165+
llm_args = update_llm_args_with_extra_options(llm_args,
166+
extra_llm_api_options)
167+
168+
profiler.start("trtllm init")
169+
llm = llm_cls(**llm_args)
167170
profiler.stop("trtllm init")
168171
elapsed_time = profiler.elapsed_time_in_sec("trtllm init")
169172
logger.info(f"TRTLLM initialization time: {elapsed_time:.3f} seconds.")

0 commit comments

Comments (0)