@@ -138,6 +138,30 @@ def test_llm_args_with_pydantic_options(self):
         assert llm_args.max_num_tokens == 256
         assert llm_args.max_seq_len == 128
 
+    def test_llm_args_with_model_kwargs_trt(self):
+        yaml_content = """
+model_kwargs:
+  num_hidden_layers: 2
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.model_kwargs['num_hidden_layers'] == 2
+
+    def test_llm_args_with_model_kwargs_pt(self):
+        yaml_content = """
+model_kwargs:
+  num_hidden_layers: 2
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+        llm_args = TorchLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
+                                                        dict_content)
+        llm_args = TorchLlmArgs(**llm_args_dict)
+        assert llm_args.model_kwargs['num_hidden_layers'] == 2
+
 
 def check_defaults(py_config_cls, pybind_config_cls):
     py_config = py_config_cls()
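Both new tests exercise the same round trip: dump the validated args to a plain dict, overlay the extra YAML options, and re-validate through the args class. Below is a minimal, self-contained sketch of that pattern; `merge_extra_options` is a hypothetical stand-in for `update_llm_args_with_extra_dict`, and the shallow-merge behavior is an assumption, not the helper's documented contract.

```python
import yaml


def merge_extra_options(args_dict: dict, extra: dict) -> dict:
    """Shallow-merge extra options over dumped args.

    Hypothetical stand-in for update_llm_args_with_extra_dict; the real
    helper may also handle nested fields and deprecated option names.
    """
    merged = dict(args_dict)
    merged.update(extra)
    return merged


yaml_content = """
model_kwargs:
  num_hidden_layers: 2
"""
extra = yaml.safe_load(yaml_content)  # {'model_kwargs': {'num_hidden_layers': 2}}

# Stand-in for llm_args.model_dump() on a validated args object.
args_dict = {'model': '/path/to/llama', 'model_kwargs': {}}

merged = merge_extra_options(args_dict, extra)
assert merged['model_kwargs']['num_hidden_layers'] == 2
```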
@@ -445,6 +469,17 @@ def test_dynamic_setattr(self):
         args = TorchLlmArgs(model=llama_model_path)
         args.invalid_arg = 1
 
+    @print_traceback_on_error
+    def test_model_kwargs_with_num_hidden_layers(self):
+        """Test that model_kwargs can override num_hidden_layers."""
+        from tensorrt_llm._torch.model_config import ModelConfig
+
+        model_kwargs = {'num_hidden_layers': 2}
+
+        config = ModelConfig.from_pretrained(llama_model_path,
+                                             model_kwargs=model_kwargs)
+        assert config.pretrained_config.num_hidden_layers == 2
+
 
 class TestTrtLlmArgs:
 
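For reference, the override semantics the test asserts mirror how HuggingFace `AutoConfig.from_pretrained` applies keyword overrides to a loaded config. A short sketch, assuming a locally available checkpoint (the path below is a placeholder, not taken from the diff):

```python
from transformers import AutoConfig

# Keyword arguments that match config attributes override the values
# stored in the checkpoint's config.json.
config = AutoConfig.from_pretrained('/path/to/llama',  # placeholder path
                                    num_hidden_layers=2)
assert config.num_hidden_layers == 2
```

This is a comparison only; whether `ModelConfig.from_pretrained` forwards `model_kwargs` this way is exactly what the test pins down.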