Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/modules/train/sm_recipes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
"mistral": ("mistral", "mistral_pretrain.py"),
"mixtral": ("mixtral", "mixtral_pretrain.py"),
"deepseek": ("deepseek", "deepseek_pretrain.py"),
"gpt_oss": ("custom_model", "custom_pretrain.py"),
}

for key in model_type_to_script:
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
"mistral": ("mistral", "mistral_pretrain.py"),
"mixtral": ("mixtral", "mixtral_pretrain.py"),
"deepseek": ("deepseek", "deepseek_pretrain.py"),
"gpt_oss": ("custom_model", "custom_pretrain.py"),
}

if "model" not in recipe:
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ def test_get_args_from_recipe_with_nova_and_role(mock_get_args_from_nova_recipe,
"script": "deepseek_pretrain.py",
"model_base_name": "deepseek",
},
{
"model_type": "gpt_oss",
"script": "custom_pretrain.py",
"model_base_name": "gpt_oss",
},
],
)
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,14 @@ def test_training_recipe_for_trainium(sagemaker_session):
},
},
},
{
"script": "custom_pretrain.py",
"recipe": {
"model": {
"model_type": "gpt_oss",
},
},
},
],
)
@patch("shutil.copyfile")
Expand Down
Loading