Skip to content

Commit b7a4792

Browse files
authored
fix: Add PT 2.1 as a supported framework for the smdistributed distribution (#4400)
* Allow smdistributed to be used on PyTorch 2.1 * Add PT 2.1 as supported framework for smdistributed * Add unit test to ensure smdistributed works with PT-2.1.0
1 parent 3e69cbe commit b7a4792

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3285,7 +3285,6 @@ class Framework(EstimatorBase):
32853285
UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = (
32863286
"2.0.1-gpu-py310-cu121",
32873287
"2.0-gpu-py310-cu121",
3288-
"2.1.0-gpu-py310",
32893288
)
32903289

32913290
def __init__(
@@ -3959,7 +3958,7 @@ def _distribution_configuration(self, distribution):
39593958
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
39603959
if (
39613960
unsupported_image in img_uri and not torch_distributed_enabled
3962-
): # disabling DLC images with CUDA12
3961+
): # disabling DLC images without SMDataParallel or SMModelParallel
39633962
raise ValueError(
39643963
f"SMDistributed is currently incompatible with DLC image: {img_uri}. "
39653964
"(Could be due to CUDA version being greater than 11.)"

tests/unit/test_fw_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ def test_validate_smdataparallel_args_not_raises():
932932
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled),
933933
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled),
934934
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled),
935+
("ml.p3.16xlarge", "pytorch", "2.1.0", "py310", smdataparallel_enabled),
935936
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
936937
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
937938
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
@@ -955,6 +956,7 @@ def test_validate_smdataparallel_args_not_raises():
955956
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled_custom_mpi),
956957
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled_custom_mpi),
957958
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled_custom_mpi),
959+
("ml.p3.16xlarge", "pytorch", "2.1.0", "py310", smdataparallel_enabled_custom_mpi),
958960
]
959961
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
960962
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)