diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index 549645cbe2..6b39add6cd 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -129,7 +129,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): """Get the model base name and script for the training recipe.""" model_type_to_script = { - "llama_v3": ("llama", "llama_pretrain.py"), + "llama": ("llama", "llama_pretrain.py"), "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), "deepseek": ("deepseek", "deepseek_pretrain.py"), diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index f5f7ceb083..585a4d2745 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -180,36 +180,36 @@ def test_get_args_from_recipe_compute( assert mock_trainium_args.call_count == 0 assert args is None - @pytest.mark.parametrize( - "test_case", - [ - { - "model_type": "llama_v3", - "script": "llama_pretrain.py", - "model_base_name": "llama_v3", - }, - { - "model_type": "mistral", - "script": "mistral_pretrain.py", - "model_base_name": "mistral", - }, - { - "model_type": "deepseek_llamav3", - "script": "deepseek_pretrain.py", - "model_base_name": "deepseek", - }, - { - "model_type": "deepseek_qwenv2", - "script": "deepseek_pretrain.py", - "model_base_name": "deepseek", - }, - ], - ) - def test_get_trainining_recipe_gpu_model_name_and_script(test_case): - model_type = test_case["model_type"] - script = test_case["script"] - model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script( - model_type, script - ) - assert model_base_name == test_case["model_base_name"] - assert script == test_case["script"] + +@pytest.mark.parametrize( + "test_case", + [ + {"model_type": "llama_v4", "script": "llama_pretrain.py", "model_base_name": "llama"}, + { + "model_type": "llama_v3", + "script": "llama_pretrain.py", + "model_base_name": "llama", + }, + { + "model_type": "mistral", + "script": "mistral_pretrain.py", + "model_base_name": "mistral", + }, + { + "model_type": "deepseek_llamav3", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + { + "model_type": "deepseek_qwenv2", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + ], +) +def test_get_trainining_recipe_gpu_model_name_and_script(test_case): + model_type = test_case["model_type"] + script = test_case["script"] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) + assert model_base_name == test_case["model_base_name"] + assert script == test_case["script"]