148148]
149149
150150
151- TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11" , "1.11.0" ]
152-
151+ TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1" ]
153152
154153TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed" ]
155-
154+ TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
155+ "1.11" ,
156+ "1.11.0" ,
157+ "1.12" ,
158+ "1.12.0" ,
159+ "1.12.1" ,
160+ "1.13.1" ,
161+ ]
156162
157163SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel" , "modelparallel" ]
158164
@@ -1055,9 +1061,8 @@ def validate_torch_distributed_distribution(
10551061 Raises:
10561062 ValueError: if
10571063 `py_version` is not python3 or
1058- `framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
1064+ `framework_version` is not compatible with instance types
10591065 """
1060-
10611066 torch_distributed_enabled = False
10621067 if "torch_distributed" in distribution :
10631068 torch_distributed_enabled = distribution .get ("torch_distributed" ).get ("enabled" , False )
@@ -1066,30 +1071,36 @@ def validate_torch_distributed_distribution(
10661071 return
10671072
10681073 err_msg = ""
1074+
10691075 if not image_uri :
10701076 # ignore framework_version and py_version if image_uri is set
10711077 # in case image_uri is not set, then both are mandatory
1072- if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS :
1073- err_msg += (
1074- f"Provided framework_version { framework_version } is not supported by"
1075- " torch_distributed.\n "
1076- "Please specify one of the supported framework versions:"
1077- f" { TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS } \n "
1078- )
10791078 if "py3" not in py_version :
10801079 err_msg += (
10811080 f"Provided py_version { py_version } is not supported by torch_distributed.\n "
1082- "Please specify py_version>=py3"
1081+ "Please specify py_version>=py3\n "
10831082 )
10841083
1085- # Check instance compatibility
1086- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1087- if match :
1088- if not match [1 ].startswith ("trn" ):
1084+ # Check instance and framework_version compatibility
1085+ if _is_gpu_instance (instance_type ):
1086+ if framework_version not in TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS :
1087+ err_msg += (
1088+ f"Provided framework_version { framework_version } is not supported by"
1089+ f" torch_distributed for instance { instance_type } .\n "
1090+ "Please specify one of the supported framework versions:"
1091+ f"{ TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS } \n "
1092+ )
1093+ elif _is_trainium_instance (instance_type ):
1094+ if framework_version not in TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS :
1095+ err_msg += (
1096+ f"Provided framework_version { framework_version } is not supported by"
1097+ f" torch_distributed for instance { instance_type } .\n "
1098+ "Please specify one of the supported framework versions:"
1099+ f"{ TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS } \n "
1100+ )
1101+ else :
10891102 err_msg += (
1090- "torch_distributed is currently supported only for trainium instances.\n "
1091- " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n " # noqa E501 # pylint: disable=c0301
1092- "for information regarding distributed training on non-trainium instances"
1103+ "Currently torch_distributed is supported only for GPU and Trainium instances.\n "
10931104 )
10941105
10951106 # Check entry point type
@@ -1103,6 +1114,41 @@ def validate_torch_distributed_distribution(
11031114 raise ValueError (err_msg )
11041115
11051116
1117+ def _is_gpu_instance (instance_type ):
1118+ """Returns bool indicating whether instance_type supports GPU
1119+
1120+ Args:
1121+ instance_type (str): Name of the instance_type to check against.
1122+
1123+ Returns:
1124+ bool: Whether or not the instance_type supports GPU
1125+ """
1126+ if isinstance (instance_type , str ):
1127+ match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1128+ if match :
1129+ if match [1 ].startswith ("p" ) or match [1 ].startswith ("g" ):
1130+ return True
1131+ if instance_type == "local_gpu" :
1132+ return True
1133+ return False
1134+
1135+
1136+ def _is_trainium_instance (instance_type ):
1137+ """Returns bool indicating whether instance_type is a Trainium instance
1138+
1139+ Args:
1140+ instance_type (str): Name of the instance_type to check against.
1141+
1142+ Returns:
1143+ bool: Whether or not the instance_type is a Trainium instance
1144+ """
1145+ if isinstance (instance_type , str ):
1146+ match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1147+ if match and match [1 ].startswith ("trn" ):
1148+ return True
1149+ return False
1150+
1151+
11061152def python_deprecation_warning (framework , latest_supported_version ):
11071153 """Placeholder docstring"""
11081154 return PYTHON_2_DEPRECATION_WARNING .format (
0 commit comments