Skip to content

Commit 152d177

Browse files
authored
fix: preserve EnableNetworkIsolation setting in attach (#1063)
* fix: preserve EnableNetworkIsolation setting in attach
1 parent 76d46d0 commit 152d177

File tree

5 files changed

+17
-0
lines changed

5 files changed

+17
-0
lines changed

src/sagemaker/algorithm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
model_channel_name="model",
5454
metric_definitions=None,
5555
encrypt_inter_container_traffic=False,
56+
**kwargs # pylint: disable=W0613
5657
):
5758
"""Initialize an ``AlgorithmEstimator`` instance.
5859
@@ -162,6 +163,8 @@ def __init__(
162163
model_channel_name:
163164
metric_definitions:
164165
encrypt_inter_container_traffic:
166+
**kwargs: Additional kwargs. This is unused. It's only added for AlgorithmEstimator
167+
to ignore the irrelevant arguments.
165168
"""
166169
self.algorithm_arn = algorithm_arn
167170
super(AlgorithmEstimator, self).__init__(

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ def __init__(
5959
default data location will be used.
6060
**kwargs:
6161
"""
62+
63+
if "enable_network_isolation" in kwargs:
64+
logger.debug(
65+
"removing unused enable_network_isolation argument: %s",
66+
str(kwargs["enable_network_isolation"]),
67+
)
68+
del kwargs["enable_network_isolation"]
69+
6270
super(AmazonAlgorithmEstimatorBase, self).__init__(
6371
role, train_instance_count, train_instance_type, **kwargs
6472
)

src/sagemaker/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,8 @@ class constructor
620620
init_params["base_job_name"] = job_details["TrainingJobName"]
621621
init_params["output_path"] = job_details["OutputDataConfig"]["S3OutputPath"]
622622
init_params["output_kms_key"] = job_details["OutputDataConfig"]["KmsKeyId"]
623+
if "EnableNetworkIsolation" in job_details:
624+
init_params["enable_network_isolation"] = job_details["EnableNetworkIsolation"]
623625

624626
has_hps = "HyperParameters" in job_details
625627
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}

tests/unit/test_estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
},
8080
"RoleArn": "arn:aws:iam::366:role/SageMakerRole",
8181
"ResourceConfig": {"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.c4.xlarge"},
82+
"EnableNetworkIsolation": False,
8283
"StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
8384
"TrainingJobName": "neo",
8485
"TrainingJobStatus": "Completed",
@@ -671,6 +672,7 @@ def test_enable_cloudwatch_metrics(sagemaker_session):
671672
def test_attach_framework(sagemaker_session):
672673
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
673674
returned_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
675+
returned_job_description["EnableNetworkIsolation"] = True
674676
sagemaker_session.sagemaker_client.describe_training_job = Mock(
675677
name="describe_training_job", return_value=returned_job_description
676678
)
@@ -694,6 +696,7 @@ def test_attach_framework(sagemaker_session):
694696
assert framework_estimator.security_group_ids == ["bar"]
695697
assert framework_estimator.encrypt_inter_container_traffic is False
696698
assert framework_estimator.tags == LIST_TAGS_RESULT["Tags"]
699+
assert framework_estimator.enable_network_isolation() is True
697700

698701

699702
def test_attach_without_hyperparameters(sagemaker_session):

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ exclude =
2525
max-complexity = 10
2626

2727
ignore =
28+
C901,
2829
E203, # whitespace before ':': Black disagrees with and explicitly violates this.
2930
FI10,
3031
FI12,

0 commit comments

Comments
 (0)