2121import pytest
2222from mock import ANY , MagicMock , Mock , patch
2323
24+ from sagemaker import vpc_utils
2425from sagemaker .amazon .amazon_estimator import registry
2526from sagemaker .algorithm import AlgorithmEstimator
2627from sagemaker .estimator import Estimator , EstimatorBase , Framework , _TrainingJob
@@ -112,9 +113,19 @@ class DummyFramework(Framework):
112113 def train_image (self ):
113114 return IMAGE_NAME
114115
115- def create_model (self , role = None , model_server_workers = None , entry_point = None ):
116+ def create_model (
117+ self ,
118+ role = None ,
119+ model_server_workers = None ,
120+ entry_point = None ,
121+ vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ,
122+ ):
116123 return DummyFrameworkModel (
117- self .sagemaker_session , vpc_config = self .get_vpc_config (), entry_point = entry_point
124+ self .sagemaker_session ,
125+ vpc_config = self .get_vpc_config (vpc_config_override ),
126+ entry_point = entry_point ,
127+ enable_network_isolation = self .enable_network_isolation (),
128+ role = role ,
118129 )
119130
120131 @classmethod
@@ -127,12 +138,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
127138
128139
129140class DummyFrameworkModel (FrameworkModel ):
130- def __init__ (self , sagemaker_session , entry_point = None , ** kwargs ):
141+ def __init__ (self , sagemaker_session , entry_point = None , role = ROLE , ** kwargs ):
131142 super (DummyFrameworkModel , self ).__init__ (
132143 MODEL_DATA ,
133144 MODEL_IMAGE ,
134- INSTANCE_TYPE ,
135- ROLE ,
145+ role ,
136146 entry_point or ENTRY_POINT ,
137147 sagemaker_session = sagemaker_session ,
138148 ** kwargs
@@ -141,7 +151,7 @@ def __init__(self, sagemaker_session, entry_point=None, **kwargs):
141151 def create_predictor (self , endpoint_name ):
142152 return None
143153
144- def prepare_container_def (self , instance_type ):
154+ def prepare_container_def (self , instance_type , accelerator_type = None ):
145155 return MODEL_CONTAINER_DEF
146156
147157
@@ -1280,22 +1290,30 @@ def test_init_with_source_dir_s3(strftime, sagemaker_session):
12801290 assert fw ._hyperparameters == expected_hyperparameters
12811291
12821292
1283- @patch ("sagemaker.estimator .name_from_image" , return_value = MODEL_IMAGE )
1293+ @patch ("sagemaker.model.utils .name_from_image" , return_value = MODEL_IMAGE )
12841294def test_framework_transformer_creation (name_from_image , sagemaker_session ):
1295+ vpc_config = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
12851296 fw = DummyFramework (
12861297 entry_point = SCRIPT_PATH ,
12871298 role = ROLE ,
12881299 train_instance_count = INSTANCE_COUNT ,
12891300 train_instance_type = INSTANCE_TYPE ,
12901301 sagemaker_session = sagemaker_session ,
1302+ subnets = vpc_config ["Subnets" ],
1303+ security_group_ids = vpc_config ["SecurityGroupIds" ],
12911304 )
12921305 fw .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
12931306
12941307 transformer = fw .transformer (INSTANCE_COUNT , INSTANCE_TYPE )
12951308
12961309 name_from_image .assert_called_with (MODEL_IMAGE )
12971310 sagemaker_session .create_model .assert_called_with (
1298- MODEL_IMAGE , ROLE , MODEL_CONTAINER_DEF , None , tags = None
1311+ MODEL_IMAGE ,
1312+ ROLE ,
1313+ MODEL_CONTAINER_DEF ,
1314+ tags = None ,
1315+ vpc_config = vpc_config ,
1316+ enable_network_isolation = False ,
12991317 )
13001318
13011319 assert isinstance (transformer , Transformer )
@@ -1307,7 +1325,7 @@ def test_framework_transformer_creation(name_from_image, sagemaker_session):
13071325 assert transformer .env == {}
13081326
13091327
1310- @patch ("sagemaker.estimator .name_from_image" , return_value = MODEL_IMAGE )
1328+ @patch ("sagemaker.model.utils .name_from_image" , return_value = MODEL_IMAGE )
13111329def test_framework_transformer_creation_with_optional_params (name_from_image , sagemaker_session ):
13121330 base_name = "foo"
13131331 vpc_config = {"Subnets" : ["foo" ], "SecurityGroupIds" : ["bar" ]}
@@ -1320,6 +1338,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13201338 base_job_name = base_name ,
13211339 subnets = vpc_config ["Subnets" ],
13221340 security_group_ids = vpc_config ["SecurityGroupIds" ],
1341+ enable_network_isolation = True ,
13231342 )
13241343 fw .latest_training_job = _TrainingJob (sagemaker_session , JOB_NAME )
13251344
@@ -1331,6 +1350,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13311350 max_payload = 6
13321351 env = {"FOO" : "BAR" }
13331352 new_role = "dummy-model-role"
1353+ new_vpc_config = {"Subnets" : ["x" ], "SecurityGroupIds" : ["y" ]}
13341354
13351355 transformer = fw .transformer (
13361356 INSTANCE_COUNT ,
@@ -1347,10 +1367,16 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13471367 env = env ,
13481368 role = new_role ,
13491369 model_server_workers = 1 ,
1370+ vpc_config_override = new_vpc_config ,
13501371 )
13511372
13521373 sagemaker_session .create_model .assert_called_with (
1353- MODEL_IMAGE , new_role , MODEL_CONTAINER_DEF , vpc_config , tags = TAGS
1374+ MODEL_IMAGE ,
1375+ new_role ,
1376+ MODEL_CONTAINER_DEF ,
1377+ vpc_config = new_vpc_config ,
1378+ tags = TAGS ,
1379+ enable_network_isolation = True ,
13541380 )
13551381 assert transformer .strategy == strategy
13561382 assert transformer .assemble_with == assemble_with
0 commit comments