21
21
import pytest
22
22
from mock import ANY , MagicMock , Mock , patch
23
23
24
+ from sagemaker import vpc_utils
24
25
from sagemaker .amazon .amazon_estimator import registry
25
26
from sagemaker .algorithm import AlgorithmEstimator
26
27
from sagemaker .estimator import Estimator , EstimatorBase , Framework , _TrainingJob
@@ -112,9 +113,19 @@ class DummyFramework(Framework):
112
113
def train_image (self ):
113
114
return IMAGE_NAME
114
115
115
- def create_model (self , role = None , model_server_workers = None , entry_point = None ):
116
+ def create_model (
117
+ self ,
118
+ role = None ,
119
+ model_server_workers = None ,
120
+ entry_point = None ,
121
+ vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ,
122
+ ):
116
123
return DummyFrameworkModel (
117
- self .sagemaker_session , vpc_config = self .get_vpc_config (), entry_point = entry_point
124
+ self .sagemaker_session ,
125
+ vpc_config = self .get_vpc_config (vpc_config_override ),
126
+ entry_point = entry_point ,
127
+ enable_network_isolation = self .enable_network_isolation (),
128
+ role = role ,
118
129
)
119
130
120
131
@classmethod
@@ -127,12 +138,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
127
138
128
139
129
140
class DummyFrameworkModel (FrameworkModel ):
130
- def __init__ (self , sagemaker_session , entry_point = None , ** kwargs ):
141
+ def __init__ (self , sagemaker_session , entry_point = None , role = ROLE , ** kwargs ):
131
142
super (DummyFrameworkModel , self ).__init__ (
132
143
MODEL_DATA ,
133
144
MODEL_IMAGE ,
134
- INSTANCE_TYPE ,
135
- ROLE ,
145
+ role ,
136
146
entry_point or ENTRY_POINT ,
137
147
sagemaker_session = sagemaker_session ,
138
148
** kwargs
@@ -141,7 +151,7 @@ def __init__(self, sagemaker_session, entry_point=None, **kwargs):
141
151
def create_predictor (self , endpoint_name ):
142
152
return None
143
153
144
- def prepare_container_def (self , instance_type ):
154
+ def prepare_container_def (self , instance_type , accelerator_type = None ):
145
155
return MODEL_CONTAINER_DEF
146
156
147
157
@@ -1280,22 +1290,30 @@ def test_init_with_source_dir_s3(strftime, sagemaker_session):
1280
1290
assert fw ._hyperparameters == expected_hyperparameters
1281
1291
1282
1292
1283
- @patch ("sagemaker.estimator .name_from_image" , return_value = MODEL_IMAGE )
1293
+ @patch ("sagemaker.model.utils .name_from_image" , return_value = MODEL_IMAGE )
1284
1294
def test_framework_transformer_creation (name_from_image , sagemaker_session ):
1295
+ vpc_config = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
1285
1296
fw = DummyFramework (
1286
1297
entry_point = SCRIPT_PATH ,
1287
1298
role = ROLE ,
1288
1299
train_instance_count = INSTANCE_COUNT ,
1289
1300
train_instance_type = INSTANCE_TYPE ,
1290
1301
sagemaker_session = sagemaker_session ,
1302
+ subnets = vpc_config ["Subnets" ],
1303
+ security_group_ids = vpc_config ["SecurityGroupIds" ],
1291
1304
)
1292
1305
fw .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
1293
1306
1294
1307
transformer = fw .transformer (INSTANCE_COUNT , INSTANCE_TYPE )
1295
1308
1296
1309
name_from_image .assert_called_with (MODEL_IMAGE )
1297
1310
sagemaker_session .create_model .assert_called_with (
1298
- MODEL_IMAGE , ROLE , MODEL_CONTAINER_DEF , None , tags = None
1311
+ MODEL_IMAGE ,
1312
+ ROLE ,
1313
+ MODEL_CONTAINER_DEF ,
1314
+ tags = None ,
1315
+ vpc_config = vpc_config ,
1316
+ enable_network_isolation = False ,
1299
1317
)
1300
1318
1301
1319
assert isinstance (transformer , Transformer )
@@ -1307,7 +1325,7 @@ def test_framework_transformer_creation(name_from_image, sagemaker_session):
1307
1325
assert transformer .env == {}
1308
1326
1309
1327
1310
- @patch ("sagemaker.estimator .name_from_image" , return_value = MODEL_IMAGE )
1328
+ @patch ("sagemaker.model.utils .name_from_image" , return_value = MODEL_IMAGE )
1311
1329
def test_framework_transformer_creation_with_optional_params (name_from_image , sagemaker_session ):
1312
1330
base_name = "foo"
1313
1331
vpc_config = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
@@ -1320,6 +1338,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
1320
1338
base_job_name = base_name ,
1321
1339
subnets = vpc_config ["Subnets" ],
1322
1340
security_group_ids = vpc_config ["SecurityGroupIds" ],
1341
+ enable_network_isolation = True ,
1323
1342
)
1324
1343
fw .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
1325
1344
@@ -1331,6 +1350,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
1331
1350
max_payload = 6
1332
1351
env = {"FOO" : "BAR" }
1333
1352
new_role = "dummy-model-role"
1353
+ new_vpc_config = {"Subnets" : ["x" ], "SecurityGroupIds" : ["y" ]}
1334
1354
1335
1355
transformer = fw .transformer (
1336
1356
INSTANCE_COUNT ,
@@ -1347,10 +1367,16 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
1347
1367
env = env ,
1348
1368
role = new_role ,
1349
1369
model_server_workers = 1 ,
1370
+ vpc_config_override = new_vpc_config ,
1350
1371
)
1351
1372
1352
1373
sagemaker_session .create_model .assert_called_with (
1353
- MODEL_IMAGE , new_role , MODEL_CONTAINER_DEF , vpc_config , tags = TAGS
1374
+ MODEL_IMAGE ,
1375
+ new_role ,
1376
+ MODEL_CONTAINER_DEF ,
1377
+ vpc_config = new_vpc_config ,
1378
+ tags = TAGS ,
1379
+ enable_network_isolation = True ,
1354
1380
)
1355
1381
assert transformer .strategy == strategy
1356
1382
assert transformer .assemble_with == assemble_with
0 commit comments