Skip to content

Commit fdc00c4

Browse files
committed
add test
1 parent b0741bb commit fdc00c4

File tree

1 file changed

+37
-32
lines changed

1 file changed

+37
-32
lines changed

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -180,36 +180,41 @@ def test_get_args_from_recipe_compute(
180180
assert mock_trainium_args.call_count == 0
181181
assert args is None
182182

183-
@pytest.mark.parametrize(
184-
"test_case",
185-
[
186-
{
187-
"model_type": "llama_v3",
188-
"script": "llama_pretrain.py",
189-
"model_base_name": "llama_v3",
190-
},
191-
{
192-
"model_type": "mistral",
193-
"script": "mistral_pretrain.py",
194-
"model_base_name": "mistral",
195-
},
196-
{
197-
"model_type": "deepseek_llamav3",
198-
"script": "deepseek_pretrain.py",
199-
"model_base_name": "deepseek",
200-
},
201-
{
202-
"model_type": "deepseek_qwenv2",
203-
"script": "deepseek_pretrain.py",
204-
"model_base_name": "deepseek",
205-
},
206-
],
183+
@pytest.mark.parametrize(
184+
"test_case",
185+
[
186+
{
187+
"model_type": "llama_v4",
188+
"script": "llama_pretrain.py",
189+
"model_base_name": "llama"
190+
},
191+
{
192+
"model_type": "llama_v3",
193+
"script": "llama_pretrain.py",
194+
"model_base_name": "llama",
195+
},
196+
{
197+
"model_type": "mistral",
198+
"script": "mistral_pretrain.py",
199+
"model_base_name": "mistral",
200+
},
201+
{
202+
"model_type": "deepseek_llamav3",
203+
"script": "deepseek_pretrain.py",
204+
"model_base_name": "deepseek",
205+
},
206+
{
207+
"model_type": "deepseek_qwenv2",
208+
"script": "deepseek_pretrain.py",
209+
"model_base_name": "deepseek",
210+
},
211+
],
212+
)
213+
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
214+
model_type = test_case["model_type"]
215+
script = test_case["script"]
216+
model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
217+
model_type
207218
)
208-
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
209-
model_type = test_case["model_type"]
210-
script = test_case["script"]
211-
model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
212-
model_type, script
213-
)
214-
assert model_base_name == test_case["model_base_name"]
215-
assert script == test_case["script"]
219+
assert model_base_name == test_case["model_base_name"]
220+
assert script == test_case["script"]

0 commit comments

Comments
 (0)