Skip to content

Commit 9a8d066

Browse files
jswudiSong Jiangahsan-z-khan
authored
fix: create profiler specific unsupported regions (#2101)
Co-authored-by: Song Jiang <[email protected]> Co-authored-by: Ahsan Khan <[email protected]>
1 parent 25d5a05 commit 9a8d066

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
UploadedCode,
5050
validate_source_dir,
5151
_region_supports_debugger,
52+
_region_supports_profiler,
5253
get_mp_parameters,
5354
)
5455
from sagemaker.inputs import TrainingInput
@@ -494,7 +495,7 @@ def _prepare_profiler_for_training(self):
494495
"""Set necessary values and do basic validations in profiler config and profiler rules.
495496
496497
When user explicitly set rules to an empty list, default profiler rule won't be enabled.
497-
Default profiler rule will be enabled when either:
498+
Default profiler rule will be enabled in supported regions when either:
498499
1. user doesn't specify any rules, i.e., rules=None; or
499500
2. user only specify debugger rules, i.e., rules=[Rule.sagemaker(...)]
500501
"""
@@ -503,7 +504,7 @@ def _prepare_profiler_for_training(self):
503504
raise RuntimeError("profiler_config cannot be set when disable_profiler is True.")
504505
if self.profiler_rules:
505506
raise RuntimeError("ProfilerRule cannot be set when disable_profiler is True.")
506-
elif _region_supports_debugger(self.sagemaker_session.boto_region_name):
507+
elif _region_supports_profiler(self.sagemaker_session.boto_region_name):
507508
if self.profiler_config is None:
508509
self.profiler_config = ProfilerConfig(s3_output_path=self.output_path)
509510
if self.rules is None or (self.rules and not self.profiler_rules):

src/sagemaker/fw_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
)
5050

5151
DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
52+
PROFILER_UNSUPPORTED_REGIONS = ("us-iso-east-1", "cn-north-1", "cn-northwest-1")
53+
5254
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
5355
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = (
5456
"ml.p3.16xlarge",
@@ -550,6 +552,19 @@ def _region_supports_debugger(region_name):
550552
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
551553

552554

555+
def _region_supports_profiler(region_name):
556+
"""Returns bool indicating whether region supports Amazon SageMaker Debugger profiling feature.
557+
558+
Args:
559+
region_name (str): Name of the region to check against.
560+
561+
Returns:
562+
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
563+
564+
"""
565+
return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
566+
567+
553568
def validate_version_or_image_args(framework_version, py_version, image_uri):
554569
"""Checks if version or image arguments are specified.
555570

tests/unit/test_estimator.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
Rule,
3636
)
3737
from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob
38+
from sagemaker.fw_utils import PROFILER_UNSUPPORTED_REGIONS
3839
from sagemaker.inputs import ShuffleConfig
3940
from sagemaker.model import FrameworkModel
4041
from sagemaker.predictor import Predictor
@@ -632,6 +633,32 @@ def test_framework_with_profiler_config_without_s3_output_path(time, sagemaker_s
632633
]
633634

634635

636+
@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS)
637+
def test_framework_with_no_default_profiler_in_unsupported_region(region):
638+
boto_mock = Mock(name="boto_session", region_name=region)
639+
sms = MagicMock(
640+
name="sagemaker_session",
641+
boto_session=boto_mock,
642+
boto_region_name=region,
643+
config=None,
644+
local_mode=False,
645+
s3_client=None,
646+
s3_resource=None,
647+
)
648+
f = DummyFramework(
649+
entry_point=SCRIPT_PATH,
650+
role=ROLE,
651+
sagemaker_session=sms,
652+
instance_count=INSTANCE_COUNT,
653+
instance_type=INSTANCE_TYPE,
654+
)
655+
f.fit("s3://mydata")
656+
sms.train.assert_called_once()
657+
_, args = sms.train.call_args
658+
assert args.get("profiler_config") is None
659+
assert args.get("profiler_rule_configs") is None
660+
661+
635662
def test_framework_with_profiler_config_and_profiler_disabled(sagemaker_session):
636663
with pytest.raises(RuntimeError) as error:
637664
f = DummyFramework(

0 commit comments

Comments
 (0)