Skip to content

Commit 6ea5a1f

Browse files
committed
Added unittest
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
1 parent 5ceb858 commit 6ea5a1f

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/unittest/llmapi/test_llm_args.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,30 @@ def test_llm_args_with_pydantic_options(self):
138138
assert llm_args.max_num_tokens == 256
139139
assert llm_args.max_seq_len == 128
140140

141+
def test_llm_args_with_model_kwargs_trt(self):
142+
yaml_content = """
143+
model_kwargs:
144+
num_hidden_layers: 2
145+
"""
146+
dict_content = self._yaml_to_dict(yaml_content)
147+
llm_args = TrtLlmArgs(model=llama_model_path)
148+
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
149+
dict_content)
150+
llm_args = TrtLlmArgs(**llm_args_dict)
151+
assert llm_args.model_kwargs['num_hidden_layers'] == 2
152+
153+
def test_llm_args_with_model_kwargs_pt(self):
154+
yaml_content = """
155+
model_kwargs:
156+
num_hidden_layers: 2
157+
"""
158+
dict_content = self._yaml_to_dict(yaml_content)
159+
llm_args = TorchLlmArgs(model=llama_model_path)
160+
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
161+
dict_content)
162+
llm_args = TorchLlmArgs(**llm_args_dict)
163+
assert llm_args.model_kwargs['num_hidden_layers'] == 2
164+
141165

142166
def check_defaults(py_config_cls, pybind_config_cls):
143167
py_config = py_config_cls()
@@ -445,6 +469,17 @@ def test_dynamic_setattr(self):
445469
args = TorchLlmArgs(model=llama_model_path)
446470
args.invalid_arg = 1
447471

472+
@print_traceback_on_error
473+
def test_model_kwargs_with_num_hidden_layers(self):
474+
"""Test that model_kwargs can override num_hidden_layers."""
475+
from tensorrt_llm._torch.model_config import ModelConfig
476+
477+
model_kwargs = {'num_hidden_layers': 2}
478+
479+
config = ModelConfig.from_pretrained(llama_model_path,
480+
model_kwargs=model_kwargs)
481+
assert config.pretrained_config.num_hidden_layers == 2
482+
448483

449484
class TestTrtLlmArgs:
450485

0 commit comments

Comments
 (0)