File tree Expand file tree Collapse file tree 3 files changed +26
-1
lines changed Expand file tree Collapse file tree 3 files changed +26
-1
lines changed Original file line number Diff line number Diff line change 4444 UploadedCode ,
4545 _region_supports_debugger ,
4646 _region_supports_profiler ,
47+ _instance_type_supports_profiler ,
4748 get_mp_parameters ,
4849 tar_and_upload_dir ,
4950 validate_source_dir ,
@@ -592,7 +593,9 @@ def __init__(
592593
593594 self .max_retry_attempts = max_retry_attempts
594595
595- if not _region_supports_profiler (self .sagemaker_session .boto_region_name ):
596+ if not _region_supports_profiler (
597+ self .sagemaker_session .boto_region_name
598+ ) or _instance_type_supports_profiler (self .instance_type ):
596599 self .disable_profiler = True
597600
598601 self .profiler_rule_configs = None
Original file line number Diff line number Diff line change @@ -1074,6 +1074,22 @@ def _region_supports_profiler(region_name):
10741074 return region_name .lower () not in PROFILER_UNSUPPORTED_REGIONS
10751075
10761076
1077+ def _instance_type_supports_profiler (instance_type ):
1078+ """Returns bool indicating whether instance_type supports SageMaker Debugger profiling feature.
1079+
1080+ Args:
1081+ instance_type (str): Name of the instance_type to check against.
1082+
1083+ Returns:
1084+ bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1085+ """
1086+ if isinstance (instance_type , str ):
1087+ match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1088+ if match and match [1 ].startswith ("trn" ):
1089+ return True
1090+ return False
1091+
1092+
10771093def validate_version_or_image_args (framework_version , py_version , image_uri ):
10781094 """Checks if version or image arguments are specified.
10791095
Original file line number Diff line number Diff line change @@ -1040,3 +1040,9 @@ def test_validate_unsupported_distributions_trainium_raises():
10401040 distribution = smdataparallel_enabled ,
10411041 instance_type = "ml.trn1.32xlarge" ,
10421042 )
1043+
1044+
1045+ def test_instance_type_supports_profiler ():
1046+ assert fw_utils ._instance_type_supports_profiler ("ml.trn1.xlarge" ) is True
1047+ assert fw_utils ._instance_type_supports_profiler ("ml.m4.xlarge" ) is False
1048+ assert fw_utils ._instance_type_supports_profiler ("local" ) is False
You can’t perform that action at this time.
0 commit comments