1 file changed: +4 −15 lines

@@ -138,28 +138,17 @@ def test_llm_args_with_pydantic_options(self):
         assert llm_args.max_num_tokens == 256
         assert llm_args.max_seq_len == 128
 
-    def test_llm_args_with_model_kwargs_trt(self):
+    @pytest.mark.parametrize("llm_args_cls", [TrtLlmArgs, TorchLlmArgs])
+    def test_llm_args_with_model_kwargs(self, llm_args_cls):
         yaml_content = """
 model_kwargs:
   num_hidden_layers: 2
 """
         dict_content = self._yaml_to_dict(yaml_content)
-        llm_args = TrtLlmArgs(model=llama_model_path)
-        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
-                                                        dict_content)
-        llm_args = TrtLlmArgs(**llm_args_dict)
-        assert llm_args.model_kwargs['num_hidden_layers'] == 2
-
-    def test_llm_args_with_model_kwargs_pt(self):
-        yaml_content = """
-model_kwargs:
-  num_hidden_layers: 2
-"""
-        dict_content = self._yaml_to_dict(yaml_content)
-        llm_args = TorchLlmArgs(model=llama_model_path)
+        llm_args = llm_args_cls(model=llama_model_path)
         llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
                                                         dict_content)
-        llm_args = TorchLlmArgs(**llm_args_dict)
+        llm_args = llm_args_cls(**llm_args_dict)
         assert llm_args.model_kwargs['num_hidden_layers'] == 2
 
 
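The key move in this diff is pytest.mark.parametrize: the two near-identical TRT/PyTorch tests collapse into a single body that pytest runs once per argument class. Below is a minimal, self-contained sketch of that pattern; the TrtLlmArgs/TorchLlmArgs definitions here are simplified stand-ins for illustration, not the real classes from tensorrt_llm.

    import pytest


    class TrtLlmArgs:
        # Stand-in for the real class under test: it only records model_kwargs.
        def __init__(self, **kwargs):
            self.model_kwargs = kwargs.get("model_kwargs", {})


    class TorchLlmArgs(TrtLlmArgs):
        # Second stand-in so the parametrization has two distinct classes.
        pass


    @pytest.mark.parametrize("llm_args_cls", [TrtLlmArgs, TorchLlmArgs])
    def test_model_kwargs_roundtrip(llm_args_cls):
        # pytest invokes this test twice, binding llm_args_cls to each class in turn,
        # so one assertion body covers both backends without duplicated code.
        args = llm_args_cls(model_kwargs={"num_hidden_layers": 2})
        assert args.model_kwargs["num_hidden_layers"] == 2

Running pytest -v on such a file reports each parametrized case as its own test item, so a failure in one backend is visible without hiding the other.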