Skip to content

Commit 9e68803

Browse files
authored
fix: pass enable_network_isolation when creating TF and SKLearn models (#1043)
1 parent 8fa5618 commit 9e68803

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def create_model(
172172
image=self.image_name,
173173
sagemaker_session=self.sagemaker_session,
174174
vpc_config=self.get_vpc_config(vpc_config_override),
175+
enable_network_isolation=self.enable_network_isolation(),
175176
**kwargs
176177
)
177178

src/sagemaker/tensorflow/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,7 @@ def _create_tfs_model(
583583
entry_point=entry_point,
584584
source_dir=source_dir,
585585
dependencies=dependencies,
586+
enable_network_isolation=self.enable_network_isolation(),
586587
)
587588

588589
def _create_default_model(
@@ -612,6 +613,7 @@ def _create_default_model(
612613
sagemaker_session=self.sagemaker_session,
613614
vpc_config=self.get_vpc_config(vpc_config_override),
614615
dependencies=dependencies or self.dependencies,
616+
enable_network_isolation=self.enable_network_isolation(),
615617
)
616618

617619
def hyperparameters(self):

tests/unit/test_sklearn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -20,12 +20,10 @@
2020
from mock import Mock
2121
from mock import patch
2222

23-
2423
from sagemaker.sklearn import defaults
2524
from sagemaker.sklearn import SKLearn
2625
from sagemaker.sklearn import SKLearnPredictor, SKLearnModel
2726

28-
2927
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
3028
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
3129
TIMESTAMP = "2017-11-06-14:14:15.672"
@@ -183,6 +181,7 @@ def test_create_model_from_estimator(sagemaker_session, sklearn_version):
183181
py_version=PYTHON_VERSION,
184182
base_job_name="job",
185183
source_dir=source_dir,
184+
enable_network_isolation=True,
186185
)
187186

188187
job_name = "new_name"
@@ -198,6 +197,7 @@ def test_create_model_from_estimator(sagemaker_session, sklearn_version):
198197
assert model.container_log_level == container_log_level
199198
assert model.source_dir == source_dir
200199
assert model.vpc_config is None
200+
assert model.enable_network_isolation()
201201

202202

203203
def test_create_model_with_optional_params(sagemaker_session):

tests/unit/test_tf_estimator.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sagemaker.fw_utils import create_image_uri
2323
from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
2424
from sagemaker.session import s3_input
25-
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
25+
from sagemaker.tensorflow import defaults, serving, TensorFlow, TensorFlowModel, TensorFlowPredictor
2626
import sagemaker.tensorflow.estimator as tfe
2727

2828
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
@@ -256,6 +256,7 @@ def test_create_model(sagemaker_session, tf_version):
256256
container_log_level=container_log_level,
257257
base_job_name="job",
258258
source_dir=source_dir,
259+
enable_network_isolation=True,
259260
)
260261

261262
job_name = "doing something"
@@ -271,6 +272,7 @@ def test_create_model(sagemaker_session, tf_version):
271272
assert model.container_log_level == container_log_level
272273
assert model.source_dir == source_dir
273274
assert model.vpc_config is None
275+
assert model.enable_network_isolation()
274276

275277

276278
def test_create_model_with_optional_params(sagemaker_session):
@@ -1002,11 +1004,23 @@ def test_script_mode_enabled(sagemaker_session):
10021004
assert tf._script_mode_enabled() is False
10031005

10041006

1005-
@patch("sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model")
1006-
def test_script_mode_create_model(create_tfs_model, sagemaker_session):
1007-
tf = _build_tf(sagemaker_session=sagemaker_session, py_version="py3")
1008-
tf.create_model()
1009-
create_tfs_model.assert_called_once()
1007+
def test_script_mode_create_model(sagemaker_session):
1008+
tf = _build_tf(
1009+
sagemaker_session=sagemaker_session, py_version="py3", enable_network_isolation=True
1010+
)
1011+
tf._prepare_for_training() # set output_path and job name as if training happened
1012+
1013+
model = tf.create_model()
1014+
1015+
assert isinstance(model, serving.Model)
1016+
1017+
assert model.model_data == tf.model_data
1018+
assert model.role == tf.role
1019+
assert model.name == tf._current_job_name
1020+
assert model.container_log_level == tf.container_log_level
1021+
assert model._framework_version == "1.11"
1022+
assert model.sagemaker_session == sagemaker_session
1023+
assert model.enable_network_isolation()
10101024

10111025

10121026
@patch("sagemaker.utils.create_tar_file", MagicMock())

0 commit comments

Comments
 (0)