diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index be3658365a..9a0e46d1a0 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -152,7 +152,6 @@ "2.1.0", "2.1.2", "2.2.0", - "2.3.0", "2.3.1", ] diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index 61971e5128..ab1398666b 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -8,7 +8,7 @@ "2.1": "2.1.2", "2.2": "2.3.1", "2.2.0": "2.3.1", - "2.3": "2.4.0" + "2.3.1": "2.4.0" }, "versions": { "2.0.1": { diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index e9c8cec292..4fd1cc6179 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -27,6 +27,7 @@ def test_smp_v2(load_config): "torch_distributed": {"enabled": True}, "smdistributed": {"modelparallel": {"enabled": True}}, } + for processor in PROCESSORS: for version in VERSIONS: ACCOUNTS = load_config["training"]["versions"][version]["registries"] @@ -38,6 +39,11 @@ def test_smp_v2(load_config): if "2.1" in version or "2.2" in version or "2.3" in version: cuda_vers = "cu121" + if "2.3.1" == version: + py_version = "py311" + + print(version, py_version) + uri = image_uris.get_training_image_uri( region, framework="pytorch",