diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 0e4e582261..234f0c61fa 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -155,6 +155,7 @@ "2.3.0", "2.3.1", "2.4.1", + "2.5.1", ] TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index 449726927a..53c2a75e13 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -9,7 +9,8 @@ "2.2": "2.3.1", "2.2.0": "2.3.1", "2.3.1": "2.5.0", - "2.4.1": "2.7.0" + "2.4.1": "2.7.0", + "2.5.1": "2.8.0" }, "versions": { "2.0.1": { @@ -186,6 +187,31 @@ "us-west-2": "658645717510" }, "repository": "smdistributed-modelparallel" + }, + "2.8.0": { + "py_versions": [ + "py311" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" } } } diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 7d277cd854..de6d622f78 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -701,12 +701,16 @@ def get_training_image_uri( if "modelparallel" in distribution["smdistributed"]: if distribution["smdistributed"]["modelparallel"].get("enabled", True): framework = "pytorch-smp" - if ( - "p5" in instance_type - or "2.1" in framework_version - or "2.2" in framework_version - or "2.3" in framework_version - or "2.4" in framework_version + supported_smp_pt_versions_cu124 = ("2.5",) + supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4") + if any( + pt_version in framework_version + for pt_version in supported_smp_pt_versions_cu124 + ): + container_version = "cu124" + elif "p5" in instance_type or any( + pt_version in framework_version + for pt_version in supported_smp_pt_versions_cu121 ): container_version = "cu121" else: diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index b1297822f7..3177384e7e 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -36,15 +36,18 @@ def test_smp_v2(load_config): for region in ACCOUNTS.keys(): for instance_type in CONTAINER_VERSIONS.keys(): cuda_vers = CONTAINER_VERSIONS[instance_type] - if ( - "2.1" in version - or "2.2" in version - or "2.3" in version - or "2.4" in version + supported_smp_pt_versions_cu124 = ("2.5",) + supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4") + if any( + pt_version in version for pt_version in supported_smp_pt_versions_cu124 + ): + cuda_vers = "cu124" + elif any( + pt_version in version for pt_version in supported_smp_pt_versions_cu121 ): cuda_vers = "cu121" - if "2.3.1" == version or "2.4.1" == version: + if version in ("2.3.1", "2.4.1", "2.5.1"): py_version = "py311" uri = image_uris.get_training_image_uri(