@@ -124,16 +124,6 @@ def __init__(self,
         else:
             llm_args_cls = TrtLlmArgs

-        # check the kwargs and raise ValueError directly
-        valid_keys = set(
-            list(llm_args_cls.model_fields.keys()) +
-            ['_mpi_session', 'backend'])
-        for key in kwargs:
-            if key not in valid_keys:
-                raise ValueError(
-                    f"{self.__class__.__name__} got invalid argument: {key}"
-                )
-
         self.args = llm_args_cls.from_kwargs(
             model=model,
             tokenizer=tokenizer,
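The deleted block re-implemented validation that the Pydantic v2 args classes can express themselves: `model_fields` already enumerates every declared field, and the diff drops the manual scan, presumably leaving unknown-keyword handling to the args classes. A minimal standalone sketch of the removed pattern, with a hypothetical `ArgsModel` standing in for `TrtLlmArgs`:

```python
from pydantic import BaseModel


class ArgsModel(BaseModel):
    """Hypothetical stand-in for TrtLlmArgs; only one field declared."""
    model: str


def check_kwargs(cls, extra_ok, **kwargs):
    # Same check as the removed code: a kwarg is valid only if it is a
    # declared Pydantic field or an explicitly whitelisted extra key.
    valid_keys = set(cls.model_fields) | set(extra_ok)
    for key in kwargs:
        if key not in valid_keys:
            raise ValueError(f"{cls.__name__} got invalid argument: {key}")


check_kwargs(ArgsModel, ['_mpi_session', 'backend'], model="llama")  # ok
# check_kwargs(ArgsModel, ['_mpi_session', 'backend'], bogus=1)  # ValueError
```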
@@ -596,7 +586,7 @@ def _build_model(self):
         max_num_tokens = max_num_tokens or build_config.max_num_tokens
         max_seq_len = max_seq_len or build_config.max_seq_len

-        self._executor_config = tllm.ExecutorConfig(
+        executor_config = tllm.ExecutorConfig(
             max_beam_width=self.args.max_beam_width,
             scheduler_config=PybindMirror.maybe_to_pybind(
                 self.args.scheduler_config),
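Every remaining hunk makes the same mechanical change: the executor configuration becomes a local `executor_config` rather than the instance attribute `self._executor_config`, so the object reaches the executor factory only once it is fully built instead of sitting half-configured on the `LLM` instance. A self-contained sketch of the before/after shape (names illustrative, not the real classes):

```python
from types import SimpleNamespace


def create_executor(executor_config):
    print("creating executor with", executor_config)


class Before:
    def build(self):
        # Half-built config escapes onto the instance, where any other
        # method could observe or mutate it mid-construction.
        self._executor_config = SimpleNamespace(max_beam_width=1)
        self._executor_config.max_seq_len = 4096
        create_executor(executor_config=self._executor_config)


class After:
    def build(self):
        # Config exists only inside the method and reaches the factory
        # explicitly, already complete.
        executor_config = SimpleNamespace(max_beam_width=1)
        executor_config.max_seq_len = 4096
        create_executor(executor_config=executor_config)


After().build()
```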
@@ -608,20 +598,20 @@ def _build_model(self):
         if self.args.backend is None:
             # also set executor_config.max_seq_len in TRT workflow, to deduce default max_tokens
             if max_seq_len is not None:
-                self._executor_config.max_seq_len = max_seq_len
+                executor_config.max_seq_len = max_seq_len
             else:
                 engine_config = EngineConfig.from_json_file(self._engine_dir /
                                                             "config.json")
-                self._executor_config.max_seq_len = engine_config.build_config.max_seq_len
+                executor_config.max_seq_len = engine_config.build_config.max_seq_len
         if self.args.kv_cache_config is not None:
-            self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
+            executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
                 self.args.kv_cache_config)
         if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
             # Disable KV cache reuse for deterministic mode
-            self._executor_config.kv_cache_config.enable_block_reuse = False
-            self._executor_config.kv_cache_config.enable_partial_reuse = False
+            executor_config.kv_cache_config.enable_block_reuse = False
+            executor_config.kv_cache_config.enable_partial_reuse = False
         if self.args.peft_cache_config is not None:
-            self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
+            executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
                 self.args.peft_cache_config)
         elif self._on_trt_backend and self.args.build_config.plugin_config.lora_plugin:
             engine_config = EngineConfig.from_json_file(self._engine_dir /
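Besides the rename, this hunk carries the `FORCE_DETERMINISTIC` escape hatch: when the variable is set, both full and partial KV-block reuse are switched off, presumably because reuse ties a request's cache layout to earlier traffic. A standalone restatement, assuming only an object carrying the two flags:

```python
import os
from types import SimpleNamespace


def apply_determinism(kv_cache_config):
    # Reused KV blocks make a request's cache layout depend on prior
    # requests, so deterministic mode disables both reuse paths.
    if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
        kv_cache_config.enable_block_reuse = False
        kv_cache_config.enable_partial_reuse = False
    return kv_cache_config


cfg = SimpleNamespace(enable_block_reuse=True, enable_partial_reuse=True)
os.environ["FORCE_DETERMINISTIC"] = "1"
print(apply_determinism(cfg))  # both flags now False
```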
@@ -630,16 +620,16 @@ def _build_model(self):
             max_lora_rank = lora_config.max_lora_rank
             num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \
                 len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
-            self._executor_config.peft_cache_config = tllm.PeftCacheConfig(
+            executor_config.peft_cache_config = tllm.PeftCacheConfig(
                 num_device_module_layer=max_lora_rank * num_lora_modules *
                 self.args.max_loras,
                 num_host_module_layer=max_lora_rank * num_lora_modules *
                 self.args.max_cpu_loras,
             )
         if self.args.decoding_config is not None:
-            self._executor_config.decoding_config = self.args.decoding_config
+            executor_config.decoding_config = self.args.decoding_config
         if self.args.guided_decoding_backend == 'xgrammar':
-            self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
+            executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
                 backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
                 XGRAMMAR,
                 **_xgrammar_tokenizer_info(self.tokenizer))
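The PEFT cache sizing above is a plain product: rank times (layers times LoRA modules per layer) times the number of resident adapters, computed separately for the GPU and host caches. A worked example with made-up numbers:

```python
# All values illustrative; the real ones come from the engine config.
num_hidden_layers = 32
lora_target_modules = ["attn_q", "attn_k", "attn_v"]
missing_qkv_modules = []            # assume no QKV module is missing
max_lora_rank = 64
max_loras, max_cpu_loras = 4, 8     # resident GPU / CPU-cached adapters

num_lora_modules = num_hidden_layers * len(lora_target_modules +
                                           missing_qkv_modules)          # 96
num_device_module_layer = max_lora_rank * num_lora_modules * max_loras   # 24576
num_host_module_layer = max_lora_rank * num_lora_modules * max_cpu_loras # 49152
print(num_device_module_layer, num_host_module_layer)
```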
@@ -648,18 +638,18 @@ def _build_model(self):
                 f"Unrecognized guided decoding backend {self.args.guided_decoding_backend}"
             )

-        self._executor_config.normalize_log_probs = self.args.normalize_log_probs
-        self._executor_config.enable_chunked_context = self.args.enable_chunked_prefill
-        self._executor_config.max_beam_width = self.args.max_beam_width or self.args.build_config.max_beam_width
+        executor_config.normalize_log_probs = self.args.normalize_log_probs
+        executor_config.enable_chunked_context = self.args.enable_chunked_prefill
+        executor_config.max_beam_width = self.args.max_beam_width or self.args.build_config.max_beam_width
         if self._on_trt_backend and self.args.extended_runtime_perf_knob_config is not None:
-            self._executor_config.extended_runtime_perf_knob_config = PybindMirror.maybe_to_pybind(
+            executor_config.extended_runtime_perf_knob_config = PybindMirror.maybe_to_pybind(
                 self.args.extended_runtime_perf_knob_config)
         if self.args.cache_transceiver_config is not None:
-            self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
+            executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
                 self.args.cache_transceiver_config)
         from tensorrt_llm._torch.pyexecutor.config import update_executor_config
         update_executor_config(
-            self._executor_config,
+            executor_config,
             backend=self.args.backend,
             pytorch_backend_config=self.args.get_pytorch_backend_config()
             if self.args.backend in ["pytorch", "_autodeploy"] else None,
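Note the guard on the last argument: only the PyTorch-family backends receive a `pytorch_backend_config`, while the TRT workflow (where `backend` is `None`) passes `None`. The same conditional, compressed into a helper for clarity (hypothetical function, not part of the API):

```python
def pick_backend_config(backend, get_pytorch_backend_config):
    # Mirrors the inline conditional: PyTorch-family backends get a
    # config object, everything else gets None.
    if backend in ["pytorch", "_autodeploy"]:
        return get_pytorch_backend_config()
    return None


print(pick_backend_config("pytorch", lambda: {"example": True}))  # {'example': True}
print(pick_backend_config(None, lambda: {"example": True}))       # None
```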
@@ -671,14 +661,14 @@ def _build_model(self):
             trt_engine_dir=self._engine_dir,
             max_input_len=self.args.max_input_len,
             max_seq_len=max_seq_len)
-        self._executor_config.llm_parallel_config = self.args.parallel_config
+        executor_config.llm_parallel_config = self.args.parallel_config
         return_logits = self.args.gather_generation_logits or (
             self._on_trt_backend and self.args.build_config
             and self.args.build_config.gather_context_logits)

         self._executor = self._executor_cls.create(
             self._engine_dir,
-            executor_config=self._executor_config,
+            executor_config=executor_config,
             batched_logits_processor=self.args.batched_logits_processor,
             model_world_size=self.args.parallel_config.world_size,
             mpi_session=self.mpi_session,
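Finally, `return_logits` folds the two logits switches into one flag before the executor is created: generation logits can be requested directly, while context logits exist only on the TRT backend and only if the engine was built to gather them. A restatement of that predicate:

```python
from types import SimpleNamespace


def want_logits(gather_generation_logits, on_trt_backend, build_config):
    # True if generation logits were requested, or if a TRT engine was
    # built with context-logit gathering enabled.
    return bool(gather_generation_logits or
                (on_trt_backend and build_config
                 and build_config.gather_context_logits))


bc = SimpleNamespace(gather_context_logits=True)
print(want_logits(False, True, bc))   # True
print(want_logits(False, False, bc))  # False
```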