@@ -180,36 +180,41 @@ def test_get_args_from_recipe_compute(
180180 assert mock_trainium_args .call_count == 0
181181 assert args is None
182182
@pytest.mark.parametrize(
    "test_case",
    [
        {
            "model_type": "llama_v4",
            "script": "llama_pretrain.py",
            "model_base_name": "llama",
        },
        {
            "model_type": "llama_v3",
            "script": "llama_pretrain.py",
            "model_base_name": "llama",
        },
        {
            "model_type": "mistral",
            "script": "mistral_pretrain.py",
            "model_base_name": "mistral",
        },
        {
            "model_type": "deepseek_llamav3",
            "script": "deepseek_pretrain.py",
            "model_base_name": "deepseek",
        },
        {
            "model_type": "deepseek_qwenv2",
            "script": "deepseek_pretrain.py",
            "model_base_name": "deepseek",
        },
    ],
)
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
    """Check the base name and script resolved for each GPU recipe model type.

    Each ``test_case`` dict carries the ``model_type`` handed to
    ``_get_trainining_recipe_gpu_model_name_and_script`` together with the
    ``model_base_name`` and ``script`` the helper is expected to return.
    """
    # The helper is called with the model type alone; the expected script is
    # read from the test case only when asserting, so it cannot leak into
    # (or be shadowed by) the value under test.
    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
        test_case["model_type"]
    )

    assert model_base_name == test_case["model_base_name"]
    assert script == test_case["script"]
0 commit comments