@@ -180,36 +180,41 @@ def test_get_args_from_recipe_compute(
180
180
assert mock_trainium_args .call_count == 0
181
181
assert args is None
182
182
183
- @pytest .mark .parametrize (
184
- "test_case" ,
185
- [
186
- {
187
- "model_type" : "llama_v3" ,
188
- "script" : "llama_pretrain.py" ,
189
- "model_base_name" : "llama_v3" ,
190
- },
191
- {
192
- "model_type" : "mistral" ,
193
- "script" : "mistral_pretrain.py" ,
194
- "model_base_name" : "mistral" ,
195
- },
196
- {
197
- "model_type" : "deepseek_llamav3" ,
198
- "script" : "deepseek_pretrain.py" ,
199
- "model_base_name" : "deepseek" ,
200
- },
201
- {
202
- "model_type" : "deepseek_qwenv2" ,
203
- "script" : "deepseek_pretrain.py" ,
204
- "model_base_name" : "deepseek" ,
205
- },
206
- ],
183
+ @pytest .mark .parametrize (
184
+ "test_case" ,
185
+ [
186
+ {
187
+ "model_type" : "llama_v4" ,
188
+ "script" : "llama_pretrain.py" ,
189
+ "model_base_name" : "llama"
190
+ },
191
+ {
192
+ "model_type" : "llama_v3" ,
193
+ "script" : "llama_pretrain.py" ,
194
+ "model_base_name" : "llama" ,
195
+ },
196
+ {
197
+ "model_type" : "mistral" ,
198
+ "script" : "mistral_pretrain.py" ,
199
+ "model_base_name" : "mistral" ,
200
+ },
201
+ {
202
+ "model_type" : "deepseek_llamav3" ,
203
+ "script" : "deepseek_pretrain.py" ,
204
+ "model_base_name" : "deepseek" ,
205
+ },
206
+ {
207
+ "model_type" : "deepseek_qwenv2" ,
208
+ "script" : "deepseek_pretrain.py" ,
209
+ "model_base_name" : "deepseek" ,
210
+ },
211
+ ],
212
+ )
213
+ def test_get_trainining_recipe_gpu_model_name_and_script (test_case ):
214
+ model_type = test_case ["model_type" ]
215
+ script = test_case ["script" ]
216
+ model_base_name , script = _get_trainining_recipe_gpu_model_name_and_script (
217
+ model_type
207
218
)
208
- def test_get_trainining_recipe_gpu_model_name_and_script (test_case ):
209
- model_type = test_case ["model_type" ]
210
- script = test_case ["script" ]
211
- model_base_name , script = _get_trainining_recipe_gpu_model_name_and_script (
212
- model_type , script
213
- )
214
- assert model_base_name == test_case ["model_base_name" ]
215
- assert script == test_case ["script" ]
219
+ assert model_base_name == test_case ["model_base_name" ]
220
+ assert script == test_case ["script" ]
0 commit comments