@@ -180,36 +180,36 @@ def test_get_args_from_recipe_compute(
180180 assert mock_trainium_args .call_count == 0
181181 assert args is None
182182
183- @ pytest . mark . parametrize (
184- "test_case" ,
185- [
186- {
187- "model_type " : "llama_v3" ,
188- "script" : "llama_pretrain.py" ,
189- "model_base_name " : "llama_v3" ,
190- } ,
191- {
192- "model_type" : "mistral" ,
193- "script" : "mistral_pretrain.py" ,
194- "model_base_name " : "mistral" ,
195- } ,
196- {
197- "model_type" : "deepseek_llamav3" ,
198- "script" : "deepseek_pretrain.py" ,
199- "model_base_name " : "deepseek " ,
200- } ,
201- {
202- "model_type" : "deepseek_qwenv2" ,
203- "script" : "deepseek_pretrain.py" ,
204- "model_base_name " : "deepseek " ,
205- } ,
206- ] ,
207- )
208- def test_get_trainining_recipe_gpu_model_name_and_script ( test_case ):
209- model_type = test_case [ "model_type" ]
210- script = test_case [ "script" ]
211- model_base_name , script = _get_trainining_recipe_gpu_model_name_and_script (
212- model_type , script
213- )
214- assert model_base_name == test_case ["model_base_name" ]
215- assert script == test_case ["script" ]
183+
184+ @ pytest . mark . parametrize (
185+ "test_case" ,
186+ [
187+ { "model_type" : "llama_v4" , "script" : "llama_pretrain.py" , "model_base_name " : "llama" } ,
188+ {
189+ "model_type " : "llama_v3" ,
190+ "script" : "llama_pretrain.py" ,
191+ "model_base_name" : "llama" ,
192+ } ,
193+ {
194+ "model_type " : "mistral" ,
195+ "script" : "mistral_pretrain.py" ,
196+ "model_base_name" : "mistral" ,
197+ } ,
198+ {
199+ "model_type " : "deepseek_llamav3 " ,
200+ "script" : "deepseek_pretrain.py" ,
201+ "model_base_name" : "deepseek" ,
202+ } ,
203+ {
204+ "model_type " : "deepseek_qwenv2 " ,
205+ "script" : "deepseek_pretrain.py" ,
206+ "model_base_name" : "deepseek" ,
207+ },
208+ ],
209+ )
210+ def test_get_trainining_recipe_gpu_model_name_and_script ( test_case ):
211+ model_type = test_case [ "model_type" ]
212+ script = test_case [ " script" ]
213+ model_base_name , script = _get_trainining_recipe_gpu_model_name_and_script ( model_type )
214+ assert model_base_name == test_case ["model_base_name" ]
215+ assert script == test_case ["script" ]
0 commit comments