@@ -21,7 +21,7 @@ class TrainingCompilerConfig(object):
2121 """The SageMaker Training Compiler configuration class."""
2222
2323 DEBUG_PATH = "/opt/ml/output/data/compiler/"
24- SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3" , "g4dn" , "p4 " ]
24+ SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3" , "g4dn" , "p4d" , "g5 " ]
2525
2626 HP_ENABLE_COMPILER = "sagemaker_training_compiler_enabled"
2727 HP_ENABLE_DEBUG = "sagemaker_training_compiler_debug_mode"
@@ -123,7 +123,7 @@ def validate(
123123 """Checks if SageMaker Training Compiler is configured correctly.
124124
125125 Args:
126- estimator (str ): A estimator object
126+ estimator (:class:`sagemaker.estimator.Estimator` ): An estimator object.
127127 When SageMaker Training Compiler is enabled, it validates if
128128 the estimator is configured to be compatible with Training Compiler.
129129
@@ -132,31 +132,34 @@ def validate(
132132 ValueError: Raised if the requested configuration is not compatible
133133 with SageMaker Training Compiler.
134134 """
135-
136- if "local" not in estimator .instance_type :
137- requested_instance_class = estimator .instance_type .split ("." )[
138- 1
139- ] # Expecting ml.class.size
140- if not any (
141- [
142- requested_instance_class .startswith (i )
143- for i in cls .SUPPORTED_INSTANCE_CLASS_PREFIXES
144- ]
145- ):
135+ if estimator .instance_type :
136+ if "local" not in estimator .instance_type :
137+ requested_instance_class = estimator .instance_type .split ("." )[
138+ 1
139+ ] # Expecting ml.class.size
140+ if not any (
141+ [
142+ requested_instance_class .startswith (i )
143+ for i in cls .SUPPORTED_INSTANCE_CLASS_PREFIXES
144+ ]
145+ ):
146+ error_helper_string = (
147+ "Unsupported Instance class {}."
148+ "SageMaker Training Compiler only supports {}"
149+ )
150+ error_helper_string = error_helper_string .format (
151+ requested_instance_class , cls .SUPPORTED_INSTANCE_CLASS_PREFIXES
152+ )
153+ raise ValueError (error_helper_string )
154+ elif estimator .instance_type == "local" :
146155 error_helper_string = (
147- "Unsupported Instance class {}. SageMaker Training Compiler only supports {}"
156+ "SageMaker Training Compiler doesn't support local mode."
157+ "It only supports the following GPU instances: {}"
148158 )
149159 error_helper_string = error_helper_string .format (
150- requested_instance_class , cls .SUPPORTED_INSTANCE_CLASS_PREFIXES
160+ cls .SUPPORTED_INSTANCE_CLASS_PREFIXES
151161 )
152162 raise ValueError (error_helper_string )
153- elif estimator .instance_type == "local" :
154- error_helper_string = (
155- "The local mode is not supported by SageMaker Training Compiler."
156- "It only supports the following GPU instances: {}"
157- )
158- error_helper_string = error_helper_string .format (cls .SUPPORTED_INSTANCE_CLASS_PREFIXES )
159- raise ValueError (error_helper_string )
160163
161164 if estimator .distribution and "smdistributed" in estimator .distribution :
162165 raise ValueError (
@@ -180,3 +183,12 @@ def validate(
180183 estimator .debugger_hook_config , estimator .disable_profiler
181184 )
182185 logger .warning (helper_string )
186+
187+ if estimator .instance_groups :
188+ raise ValueError (
189+ "SageMaker Training Compiler currently only supports homogeneous clusters of "
190+ "the following GPU instance families: {}. Please use the 'instance_type' "
191+ "and 'instance_count' parameters instead of 'instance_groups'" .format (
192+ cls .SUPPORTED_INSTANCE_CLASS_PREFIXES
193+ )
194+ )
0 commit comments