Skip to content

Commit 1d94c99

Browse files
authored
fix: allow model with network isolation when creating a Transformer from an Estimator (#1394)
1 parent bc71d77 commit 1d94c99

File tree

5 files changed

+126
-25
lines changed

5 files changed

+126
-25
lines changed

src/sagemaker/estimator.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,7 @@ def transformer(
825825
role=None,
826826
volume_kms_key=None,
827827
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
828+
enable_network_isolation=None,
828829
):
829830
"""Return a ``Transformer`` that uses a SageMaker Model based on the
830831
training job. It reuses the SageMaker Session and base job name used by
@@ -863,8 +864,18 @@ def transformer(
863864
vpc_config_override (dict[str, list[str]]): Optional override for the
864865
VpcConfig set on the model.
865866
Default: use subnets and security groups from this Estimator.
867+
866868
* 'Subnets' (list[str]): List of subnet ids.
867869
* 'SecurityGroupIds' (list[str]): List of security group ids.
870+
871+
enable_network_isolation (bool): Specifies whether container will
872+
run in network isolation mode. Network isolation mode restricts
873+
the container access to outside networks (such as the internet).
874+
The container does not make any inbound or outbound network
875+
calls. If True, a channel named "code" will be created for any
876+
user entry script for inference. Also known as Internet-free mode.
877+
If not specified, this setting is taken from the estimator's
878+
current configuration.
868879
"""
869880
tags = tags or self.tags
870881

@@ -876,8 +887,13 @@ def transformer(
876887
model_name = self._current_job_name
877888
else:
878889
model_name = self.latest_training_job.name
890+
if enable_network_isolation is None:
891+
enable_network_isolation = self.enable_network_isolation()
892+
879893
model = self.create_model(
880-
vpc_config_override=vpc_config_override, model_kms_key=self.output_kms_key
894+
vpc_config_override=vpc_config_override,
895+
model_kms_key=self.output_kms_key,
896+
enable_network_isolation=enable_network_isolation,
881897
)
882898

883899
# not all create_model() implementations have the same kwargs
@@ -1354,14 +1370,16 @@ def predict_wrapper(endpoint, session):
13541370

13551371
role = role or self.role
13561372

1373+
if "enable_network_isolation" not in kwargs:
1374+
kwargs["enable_network_isolation"] = self.enable_network_isolation()
1375+
13571376
return Model(
13581377
self.model_data,
13591378
image or self.train_image(),
13601379
role,
13611380
vpc_config=self.get_vpc_config(vpc_config_override),
13621381
sagemaker_session=self.sagemaker_session,
13631382
predictor_cls=predictor_cls,
1364-
enable_network_isolation=self.enable_network_isolation(),
13651383
**kwargs
13661384
)
13671385

@@ -1878,6 +1896,7 @@ def transformer(
18781896
volume_kms_key=None,
18791897
entry_point=None,
18801898
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
1899+
enable_network_isolation=None,
18811900
):
18821901
"""Return a ``Transformer`` that uses a SageMaker Model based on the
18831902
training job. It reuses the SageMaker Session and base job name used by
@@ -1922,9 +1941,19 @@ def transformer(
19221941
vpc_config_override (dict[str, list[str]]): Optional override for
19231942
the VpcConfig set on the model.
19241943
Default: use subnets and security groups from this Estimator.
1944+
19251945
* 'Subnets' (list[str]): List of subnet ids.
19261946
* 'SecurityGroupIds' (list[str]): List of security group ids.
19271947
1948+
enable_network_isolation (bool): Specifies whether container will
1949+
run in network isolation mode. Network isolation mode restricts
1950+
the container access to outside networks (such as the internet).
1951+
The container does not make any inbound or outbound network
1952+
calls. If True, a channel named "code" will be created for any
1953+
user entry script for inference. Also known as Internet-free mode.
1954+
If not specified, this setting is taken from the estimator's
1955+
current configuration.
1956+
19281957
Returns:
19291958
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
19301959
SageMaker Batch Transform job.
@@ -1933,12 +1962,16 @@ def transformer(
19331962
tags = tags or self.tags
19341963

19351964
if self.latest_training_job is not None:
1965+
if enable_network_isolation is None:
1966+
enable_network_isolation = self.enable_network_isolation()
1967+
19361968
model = self.create_model(
19371969
role=role,
19381970
model_server_workers=model_server_workers,
19391971
entry_point=entry_point,
19401972
vpc_config_override=vpc_config_override,
19411973
model_kms_key=self.output_kms_key,
1974+
enable_network_isolation=enable_network_isolation,
19421975
)
19431976
model._create_sagemaker_model(instance_type, tags=tags)
19441977

src/sagemaker/sklearn/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def create_model(
174174
else:
175175
image = None
176176

177+
if "enable_network_isolation" not in kwargs:
178+
kwargs["enable_network_isolation"] = self.enable_network_isolation()
179+
177180
return SKLearnModel(
178181
self.model_data,
179182
role,
@@ -189,7 +192,6 @@ def create_model(
189192
image=image or self.image_name,
190193
sagemaker_session=self.sagemaker_session,
191194
vpc_config=self.get_vpc_config(vpc_config_override),
192-
enable_network_isolation=self.enable_network_isolation(),
193195
**kwargs
194196
)
195197

src/sagemaker/tensorflow/estimator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,7 @@ def transformer(
791791
endpoint_type=None,
792792
entry_point=None,
793793
vpc_config_override=VPC_CONFIG_DEFAULT,
794+
enable_network_isolation=None,
794795
):
795796
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It
796797
reuses the SageMaker Session and base job name used by the Estimator.
@@ -836,8 +837,18 @@ def transformer(
836837
vpc_config_override (dict[str, list[str]]): Optional override for
837838
the VpcConfig set on the model.
838839
Default: use subnets and security groups from this Estimator.
840+
839841
* 'Subnets' (list[str]): List of subnet ids.
840842
* 'SecurityGroupIds' (list[str]): List of security group ids.
843+
844+
enable_network_isolation (bool): Specifies whether container will
845+
run in network isolation mode. Network isolation mode restricts
846+
the container access to outside networks (such as the internet).
847+
The container does not make any inbound or outbound network
848+
calls. If True, a channel named "code" will be created for any
849+
user entry script for inference. Also known as Internet-free mode.
850+
If not specified, this setting is taken from the estimator's
851+
current configuration.
841852
"""
842853
role = role or self.role
843854

@@ -864,13 +875,18 @@ def transformer(
864875
sagemaker_session=self.sagemaker_session,
865876
)
866877

878+
if enable_network_isolation is None:
879+
enable_network_isolation = self.enable_network_isolation()
880+
867881
model = self.create_model(
868882
model_server_workers=model_server_workers,
869883
role=role,
870884
vpc_config_override=vpc_config_override,
871885
endpoint_type=endpoint_type,
872886
entry_point=entry_point,
887+
enable_network_isolation=enable_network_isolation,
873888
)
889+
874890
return model.transformer(
875891
instance_count,
876892
instance_type,

tests/unit/test_estimator.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,17 @@ def create_model(
121121
model_server_workers=None,
122122
entry_point=None,
123123
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
124+
enable_network_isolation=None,
124125
**kwargs
125126
):
127+
if enable_network_isolation is None:
128+
enable_network_isolation = self.enable_network_isolation()
129+
126130
return DummyFrameworkModel(
127131
self.sagemaker_session,
128132
vpc_config=self.get_vpc_config(vpc_config_override),
129133
entry_point=entry_point,
130-
enable_network_isolation=self.enable_network_isolation(),
134+
enable_network_isolation=enable_network_isolation,
131135
role=role,
132136
**kwargs
133137
)
@@ -1357,7 +1361,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13571361
base_job_name=base_name,
13581362
subnets=vpc_config["Subnets"],
13591363
security_group_ids=vpc_config["SecurityGroupIds"],
1360-
enable_network_isolation=True,
1364+
enable_network_isolation=False,
13611365
)
13621366
fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
13631367

@@ -1387,6 +1391,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13871391
role=new_role,
13881392
model_server_workers=1,
13891393
vpc_config_override=new_vpc_config,
1394+
enable_network_isolation=True,
13901395
)
13911396

13921397
sagemaker_session.create_model.assert_called_with(
@@ -1437,8 +1442,8 @@ def test_ensure_latest_training_job_failure(sagemaker_session):
14371442
assert "Estimator is not associated with a training job" in str(e)
14381443

14391444

1440-
@patch("sagemaker.estimator.Estimator.create_model", return_value=Mock())
1441-
def test_estimator_transformer_creation(sagemaker_session):
1445+
@patch("sagemaker.estimator.Estimator.create_model")
1446+
def test_estimator_transformer_creation(create_model, sagemaker_session):
14421447
estimator = Estimator(
14431448
image_name=IMAGE_NAME,
14441449
role=ROLE,
@@ -1450,6 +1455,12 @@ def test_estimator_transformer_creation(sagemaker_session):
14501455

14511456
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
14521457

1458+
create_model.assert_called_with(
1459+
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
1460+
model_kms_key=estimator.output_kms_key,
1461+
enable_network_isolation=False,
1462+
)
1463+
14531464
assert isinstance(transformer, Transformer)
14541465
assert transformer.sagemaker_session == sagemaker_session
14551466
assert transformer.instance_count == INSTANCE_COUNT
@@ -1458,26 +1469,29 @@ def test_estimator_transformer_creation(sagemaker_session):
14581469
assert transformer.tags is None
14591470

14601471

1461-
@patch("sagemaker.estimator.Estimator.create_model", return_value=Mock())
1462-
def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
1472+
@patch("sagemaker.estimator.Estimator.create_model")
1473+
def test_estimator_transformer_creation_with_optional_params(create_model, sagemaker_session):
14631474
base_name = "foo"
1475+
kms_key = "key"
1476+
14641477
estimator = Estimator(
14651478
image_name=IMAGE_NAME,
14661479
role=ROLE,
14671480
train_instance_count=INSTANCE_COUNT,
14681481
train_instance_type=INSTANCE_TYPE,
14691482
sagemaker_session=sagemaker_session,
14701483
base_job_name=base_name,
1484+
output_kms_key=kms_key,
14711485
)
14721486
estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
14731487

14741488
strategy = "MultiRecord"
14751489
assemble_with = "Line"
1476-
kms_key = "key"
14771490
accept = "text/csv"
14781491
max_concurrent_transforms = 1
14791492
max_payload = 6
14801493
env = {"FOO": "BAR"}
1494+
new_vpc_config = {"Subnets": ["x"], "SecurityGroupIds": ["y"]}
14811495

14821496
transformer = estimator.transformer(
14831497
INSTANCE_COUNT,
@@ -1492,6 +1506,12 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
14921506
max_payload=max_payload,
14931507
env=env,
14941508
role=ROLE,
1509+
vpc_config_override=new_vpc_config,
1510+
enable_network_isolation=True,
1511+
)
1512+
1513+
create_model.assert_called_with(
1514+
vpc_config_override=new_vpc_config, model_kms_key=kms_key, enable_network_isolation=True
14951515
)
14961516

14971517
assert transformer.strategy == strategy

tests/unit/test_tf_estimator.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def test_create_model_with_optional_params(sagemaker_session):
336336

337337

338338
@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
339-
def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session):
339+
def test_transformer_creation_with_optional_args(create_model, sagemaker_session):
340340
model = Mock()
341341
create_model.return_value = model
342342

@@ -348,38 +348,67 @@ def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session
348348
train_instance_type=INSTANCE_TYPE,
349349
)
350350
tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")
351+
352+
strategy = "SingleRecord"
353+
assemble_with = "Line"
354+
output_path = "s3://{}/batch-output".format(BUCKET_NAME)
355+
kms_key = "kms"
356+
accept_type = "text/bytes"
357+
env = {"foo": "bar"}
358+
max_concurrent_transforms = 3
359+
max_payload = 100
360+
tags = {"Key": "foo", "Value": "bar"}
361+
new_role = "role"
362+
model_server_workers = 2
363+
vpc_config = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]}
364+
351365
tf.transformer(
352366
INSTANCE_COUNT,
353367
INSTANCE_TYPE,
368+
strategy=strategy,
369+
assemble_with=assemble_with,
370+
output_path=output_path,
371+
output_kms_key=kms_key,
372+
accept=accept_type,
373+
env=env,
374+
max_concurrent_transforms=max_concurrent_transforms,
375+
max_payload=max_payload,
376+
tags=tags,
377+
role=new_role,
378+
model_server_workers=model_server_workers,
379+
volume_kms_key=kms_key,
354380
endpoint_type="tensorflow-serving",
355381
entry_point=SERVING_SCRIPT_FILE,
382+
vpc_config_override=vpc_config,
383+
enable_network_isolation=True,
356384
)
357385

358386
create_model.assert_called_with(
387+
model_server_workers=model_server_workers,
388+
role=new_role,
389+
vpc_config_override=vpc_config,
359390
endpoint_type="tensorflow-serving",
360-
model_server_workers=None,
361-
role=ROLE,
362-
vpc_config_override="VPC_CONFIG_DEFAULT",
363391
entry_point=SERVING_SCRIPT_FILE,
392+
enable_network_isolation=True,
364393
)
365394
model.transformer.assert_called_with(
366395
INSTANCE_COUNT,
367396
INSTANCE_TYPE,
368-
accept=None,
369-
assemble_with=None,
370-
env=None,
371-
max_concurrent_transforms=None,
372-
max_payload=None,
373-
output_kms_key=None,
374-
output_path=None,
375-
strategy=None,
376-
tags=None,
377-
volume_kms_key=None,
397+
accept=accept_type,
398+
assemble_with=assemble_with,
399+
env=env,
400+
max_concurrent_transforms=max_concurrent_transforms,
401+
max_payload=max_payload,
402+
output_kms_key=kms_key,
403+
output_path=output_path,
404+
strategy=strategy,
405+
tags=tags,
406+
volume_kms_key=kms_key,
378407
)
379408

380409

381410
@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
382-
def test_transformer_creation_without_endpoint_type(create_model, sagemaker_session):
411+
def test_transformer_creation_without_optional_args(create_model, sagemaker_session):
383412
model = Mock()
384413
create_model.return_value = model
385414

@@ -399,6 +428,7 @@ def test_transformer_creation_without_endpoint_type(create_model, sagemaker_sess
399428
role=ROLE,
400429
vpc_config_override="VPC_CONFIG_DEFAULT",
401430
entry_point=None,
431+
enable_network_isolation=False,
402432
)
403433
model.transformer.assert_called_with(
404434
INSTANCE_COUNT,

0 commit comments

Comments
 (0)