Skip to content

Commit 953087e

Browse files
authored
fix: expose vpc_config_override in transformer() methods (#1042)
1 parent 90dbb1a commit 953087e

File tree

3 files changed

+64
-22
lines changed

3 files changed

+64
-22
lines changed

src/sagemaker/estimator.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from sagemaker.session import Session
4848
from sagemaker.session import s3_input
4949
from sagemaker.transformer import Transformer
50-
from sagemaker.utils import base_name_from_image, name_from_base, name_from_image, get_config_value
50+
from sagemaker.utils import base_name_from_image, name_from_base, get_config_value
5151
from sagemaker import vpc_utils
5252

5353

@@ -667,6 +667,7 @@ def transformer(
667667
tags=None,
668668
role=None,
669669
volume_kms_key=None,
670+
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
670671
):
671672
"""Return a ``Transformer`` that uses a SageMaker Model based on the
672673
training job. It reuses the SageMaker Session and base job name used by
@@ -702,6 +703,11 @@ def transformer(
702703
role from the Estimator will be used.
703704
volume_kms_key (str): Optional. KMS key ID for encrypting the volume
704705
attached to the ML compute instance (default: None).
706+
vpc_config_override (dict[str, list[str]]): Optional override for the
707+
VpcConfig set on the model.
708+
Default: use subnets and security groups from this Estimator.
709+
* 'Subnets' (list[str]): List of subnet ids.
710+
* 'SecurityGroupIds' (list[str]): List of security group ids.
705711
"""
706712
tags = tags or self.tags
707713

@@ -714,7 +720,7 @@ def transformer(
714720
else:
715721
model_name = self.latest_training_job.name
716722

717-
model = self.create_model()
723+
model = self.create_model(vpc_config_override=vpc_config_override)
718724

719725
# not all create_model() implementations have the same kwargs
720726
model.name = model_name
@@ -1635,6 +1641,7 @@ def transformer(
16351641
model_server_workers=None,
16361642
volume_kms_key=None,
16371643
entry_point=None,
1644+
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
16381645
):
16391646
"""Return a ``Transformer`` that uses a SageMaker Model based on the
16401647
training job. It reuses the SageMaker Session and base job name used by
@@ -1676,25 +1683,29 @@ def transformer(
16761683
entry_point (str): Path (absolute or relative) to the local Python source file which
16771684
should be executed as the entry point to training. If not specified, the training
16781685
entry point is used.
1686+
vpc_config_override (dict[str, list[str]]): Optional override for
1687+
the VpcConfig set on the model.
1688+
Default: use subnets and security groups from this Estimator.
1689+
* 'Subnets' (list[str]): List of subnet ids.
1690+
* 'SecurityGroupIds' (list[str]): List of security group ids.
16791691
16801692
Returns:
16811693
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
16821694
SageMaker Batch Transform job.
16831695
"""
16841696
role = role or self.role
1697+
tags = tags or self.tags
16851698

16861699
if self.latest_training_job is not None:
16871700
model = self.create_model(
1688-
role=role, model_server_workers=model_server_workers, entry_point=entry_point
1701+
role=role,
1702+
model_server_workers=model_server_workers,
1703+
entry_point=entry_point,
1704+
vpc_config_override=vpc_config_override,
16891705
)
1706+
model._create_sagemaker_model(instance_type, tags=tags)
16901707

1691-
container_def = model.prepare_container_def(instance_type)
1692-
model_name = model.name or name_from_image(container_def["Image"])
1693-
vpc_config = model.vpc_config
1694-
tags = tags or self.tags
1695-
self.sagemaker_session.create_model(
1696-
model_name, role, container_def, vpc_config, tags=tags
1697-
)
1708+
model_name = model.name
16981709
transform_env = model.env.copy()
16991710
if env is not None:
17001711
transform_env.update(env)
@@ -1706,7 +1717,6 @@ def transformer(
17061717
model_name = self._current_job_name
17071718
transform_env = env or {}
17081719

1709-
tags = tags or self.tags
17101720
return Transformer(
17111721
model_name,
17121722
instance_count,

src/sagemaker/tensorflow/estimator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,7 @@ def transformer(
704704
volume_kms_key=None,
705705
endpoint_type=None,
706706
entry_point=None,
707+
vpc_config_override=VPC_CONFIG_DEFAULT,
707708
):
708709
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It
709710
reuses the SageMaker Session and base job name used by the Estimator.
@@ -746,13 +747,18 @@ def transformer(
746747
should be executed as the entry point to training. If not specified and
747748
``endpoint_type`` is 'tensorflow-serving', no entry point is used. If
748749
``endpoint_type`` is also ``None``, then the training entry point is used.
750+
vpc_config_override (dict[str, list[str]]): Optional override for
751+
the VpcConfig set on the model.
752+
Default: use subnets and security groups from this Estimator.
753+
* 'Subnets' (list[str]): List of subnet ids.
754+
* 'SecurityGroupIds' (list[str]): List of security group ids.
749755
"""
750756

751757
role = role or self.role
752758
model = self.create_model(
753759
model_server_workers=model_server_workers,
754760
role=role,
755-
vpc_config_override=VPC_CONFIG_DEFAULT,
761+
vpc_config_override=vpc_config_override,
756762
endpoint_type=endpoint_type,
757763
entry_point=entry_point,
758764
)

tests/unit/test_estimator.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pytest
2222
from mock import ANY, MagicMock, Mock, patch
2323

24+
from sagemaker import vpc_utils
2425
from sagemaker.amazon.amazon_estimator import registry
2526
from sagemaker.algorithm import AlgorithmEstimator
2627
from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob
@@ -112,9 +113,19 @@ class DummyFramework(Framework):
112113
def train_image(self):
113114
return IMAGE_NAME
114115

115-
def create_model(self, role=None, model_server_workers=None, entry_point=None):
116+
def create_model(
117+
self,
118+
role=None,
119+
model_server_workers=None,
120+
entry_point=None,
121+
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
122+
):
116123
return DummyFrameworkModel(
117-
self.sagemaker_session, vpc_config=self.get_vpc_config(), entry_point=entry_point
124+
self.sagemaker_session,
125+
vpc_config=self.get_vpc_config(vpc_config_override),
126+
entry_point=entry_point,
127+
enable_network_isolation=self.enable_network_isolation(),
128+
role=role,
118129
)
119130

120131
@classmethod
@@ -127,12 +138,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
127138

128139

129140
class DummyFrameworkModel(FrameworkModel):
130-
def __init__(self, sagemaker_session, entry_point=None, **kwargs):
141+
def __init__(self, sagemaker_session, entry_point=None, role=ROLE, **kwargs):
131142
super(DummyFrameworkModel, self).__init__(
132143
MODEL_DATA,
133144
MODEL_IMAGE,
134-
INSTANCE_TYPE,
135-
ROLE,
145+
role,
136146
entry_point or ENTRY_POINT,
137147
sagemaker_session=sagemaker_session,
138148
**kwargs
@@ -141,7 +151,7 @@ def __init__(self, sagemaker_session, entry_point=None, **kwargs):
141151
def create_predictor(self, endpoint_name):
142152
return None
143153

144-
def prepare_container_def(self, instance_type):
154+
def prepare_container_def(self, instance_type, accelerator_type=None):
145155
return MODEL_CONTAINER_DEF
146156

147157

@@ -1280,22 +1290,30 @@ def test_init_with_source_dir_s3(strftime, sagemaker_session):
12801290
assert fw._hyperparameters == expected_hyperparameters
12811291

12821292

1283-
@patch("sagemaker.estimator.name_from_image", return_value=MODEL_IMAGE)
1293+
@patch("sagemaker.model.utils.name_from_image", return_value=MODEL_IMAGE)
12841294
def test_framework_transformer_creation(name_from_image, sagemaker_session):
1295+
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
12851296
fw = DummyFramework(
12861297
entry_point=SCRIPT_PATH,
12871298
role=ROLE,
12881299
train_instance_count=INSTANCE_COUNT,
12891300
train_instance_type=INSTANCE_TYPE,
12901301
sagemaker_session=sagemaker_session,
1302+
subnets=vpc_config["Subnets"],
1303+
security_group_ids=vpc_config["SecurityGroupIds"],
12911304
)
12921305
fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
12931306

12941307
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
12951308

12961309
name_from_image.assert_called_with(MODEL_IMAGE)
12971310
sagemaker_session.create_model.assert_called_with(
1298-
MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None, tags=None
1311+
MODEL_IMAGE,
1312+
ROLE,
1313+
MODEL_CONTAINER_DEF,
1314+
tags=None,
1315+
vpc_config=vpc_config,
1316+
enable_network_isolation=False,
12991317
)
13001318

13011319
assert isinstance(transformer, Transformer)
@@ -1307,7 +1325,7 @@ def test_framework_transformer_creation(name_from_image, sagemaker_session):
13071325
assert transformer.env == {}
13081326

13091327

1310-
@patch("sagemaker.estimator.name_from_image", return_value=MODEL_IMAGE)
1328+
@patch("sagemaker.model.utils.name_from_image", return_value=MODEL_IMAGE)
13111329
def test_framework_transformer_creation_with_optional_params(name_from_image, sagemaker_session):
13121330
base_name = "foo"
13131331
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
@@ -1320,6 +1338,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13201338
base_job_name=base_name,
13211339
subnets=vpc_config["Subnets"],
13221340
security_group_ids=vpc_config["SecurityGroupIds"],
1341+
enable_network_isolation=True,
13231342
)
13241343
fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
13251344

@@ -1331,6 +1350,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13311350
max_payload = 6
13321351
env = {"FOO": "BAR"}
13331352
new_role = "dummy-model-role"
1353+
new_vpc_config = {"Subnets": ["x"], "SecurityGroupIds": ["y"]}
13341354

13351355
transformer = fw.transformer(
13361356
INSTANCE_COUNT,
@@ -1347,10 +1367,16 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13471367
env=env,
13481368
role=new_role,
13491369
model_server_workers=1,
1370+
vpc_config_override=new_vpc_config,
13501371
)
13511372

13521373
sagemaker_session.create_model.assert_called_with(
1353-
MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config, tags=TAGS
1374+
MODEL_IMAGE,
1375+
new_role,
1376+
MODEL_CONTAINER_DEF,
1377+
vpc_config=new_vpc_config,
1378+
tags=TAGS,
1379+
enable_network_isolation=True,
13541380
)
13551381
assert transformer.strategy == strategy
13561382
assert transformer.assemble_with == assemble_with

0 commit comments

Comments
 (0)