@@ -143,27 +143,30 @@ def main(ctx, model: str, tokenizer: Optional[str],
         "kv_cache_config": kv_cache_config,
     }
 
-    if extra_llm_api_options is not None:
-        llm_args = update_llm_args_with_extra_options(llm_args,
-                                                      extra_llm_api_options)
-
-    profiler.start("trtllm init")
     if backend == 'pytorch':
-        llm = PyTorchLLM(**llm_args,
-                         max_batch_size=max_batch_size,
-                         max_num_tokens=max_num_tokens,
-                         max_beam_width=max_beam_width,
-                         max_seq_len=max_seq_len)
+        llm_cls = PyTorchLLM
+        llm_args.update(max_batch_size=max_batch_size,
+                        max_num_tokens=max_num_tokens,
+                        max_beam_width=max_beam_width,
+                        max_seq_len=max_seq_len)
     elif backend == 'tensorrt':
+        llm_cls = LLM
         build_config = BuildConfig(max_batch_size=max_batch_size,
                                    max_num_tokens=max_num_tokens,
                                    max_beam_width=max_beam_width,
                                    max_seq_len=max_seq_len)
-        llm = LLM(**llm_args, build_config=build_config)
+        llm_args.update(build_config=build_config)
     else:
         raise click.BadParameter(
             f"{backend} is not a known backend, check help for available options.",
             param_hint="backend")
+
+    if extra_llm_api_options is not None:
+        llm_args = update_llm_args_with_extra_options(llm_args,
+                                                      extra_llm_api_options)
+
+    profiler.start("trtllm init")
+    llm = llm_cls(**llm_args)
     profiler.stop("trtllm init")
     elapsed_time = profiler.elapsed_time_in_sec("trtllm init")
     logger.info(f"TRTLLM initialization time: {elapsed_time:.3f} seconds.")
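
The net effect of the hunk: each backend branch now only selects a class and folds its backend-specific arguments into `llm_args`, so `extra_llm_api_options` is applied after those defaults, and a single `llm_cls(**llm_args)` call replaces the two per-backend constructors. Below is a minimal, self-contained sketch of this select-then-construct pattern; `LLM`, `PyTorchLLM`, and `update_llm_args_with_extra_options` are stubs standing in for the real `tensorrt_llm` objects, and their bodies are assumptions made for illustration, not the library's actual implementation.

```python
# Sketch of the select-then-construct pattern from the diff above.
# LLM, PyTorchLLM, and update_llm_args_with_extra_options are stand-ins
# for the real tensorrt_llm objects; only the control flow mirrors the hunk.
from typing import Optional


class LLM:
    """Stub for the TensorRT-backend LLM class."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class PyTorchLLM(LLM):
    """Stub for the PyTorch-backend LLM class."""


def update_llm_args_with_extra_options(llm_args: dict, path: str) -> dict:
    # Stand-in: the real helper merges options read from a file at `path`
    # into llm_args. Here we only record that it ran.
    llm_args["extra_options_file"] = path
    return llm_args


def build_llm(backend: str,
              llm_args: dict,
              extra_llm_api_options: Optional[str] = None) -> LLM:
    # 1) Select the class and fold backend-specific settings into one dict,
    #    instead of constructing the object inside each branch.
    if backend == "pytorch":
        llm_cls = PyTorchLLM
        llm_args.update(max_batch_size=2048, max_num_tokens=8192)
    elif backend == "tensorrt":
        llm_cls = LLM
        llm_args.update(build_config={"max_batch_size": 2048})
    else:
        raise ValueError(f"{backend} is not a known backend")

    # 2) Apply user-supplied overrides after the backend defaults, matching
    #    the diff's reordering: extra options can now win over CLI values.
    if extra_llm_api_options is not None:
        llm_args = update_llm_args_with_extra_options(llm_args,
                                                      extra_llm_api_options)

    # 3) One construction site serves both backends.
    return llm_cls(**llm_args)


if __name__ == "__main__":
    llm = build_llm("pytorch", {"model": "example"})
    print(type(llm).__name__, llm.kwargs)
```

The reordering is the behavioral change worth noting: before the diff, extra options were merged before the backend branch filled in `max_batch_size` and friends; after it, they are merged last and can override those values.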