2222from sagemaker .fw_utils import create_image_uri
2323from sagemaker .model import MODEL_SERVER_WORKERS_PARAM_NAME
2424from sagemaker .session import s3_input
25- from sagemaker .tensorflow import defaults , TensorFlow , TensorFlowModel , TensorFlowPredictor
25+ from sagemaker .tensorflow import defaults , serving , TensorFlow , TensorFlowModel , TensorFlowPredictor
2626import sagemaker .tensorflow .estimator as tfe
2727
2828DATA_DIR = os .path .join (os .path .dirname (__file__ ), ".." , "data" )
@@ -256,6 +256,7 @@ def test_create_model(sagemaker_session, tf_version):
256256 container_log_level = container_log_level ,
257257 base_job_name = "job" ,
258258 source_dir = source_dir ,
259+ enable_network_isolation = True ,
259260 )
260261
261262 job_name = "doing something"
@@ -271,6 +272,7 @@ def test_create_model(sagemaker_session, tf_version):
271272 assert model .container_log_level == container_log_level
272273 assert model .source_dir == source_dir
273274 assert model .vpc_config is None
275+ assert model .enable_network_isolation ()
274276
275277
276278def test_create_model_with_optional_params (sagemaker_session ):
@@ -1002,11 +1004,23 @@ def test_script_mode_enabled(sagemaker_session):
10021004 assert tf ._script_mode_enabled () is False
10031005
10041006
1005- @patch ("sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model" )
1006- def test_script_mode_create_model (create_tfs_model , sagemaker_session ):
1007- tf = _build_tf (sagemaker_session = sagemaker_session , py_version = "py3" )
1008- tf .create_model ()
1009- create_tfs_model .assert_called_once ()
1007+ def test_script_mode_create_model (sagemaker_session ):
1008+ tf = _build_tf (
1009+ sagemaker_session = sagemaker_session , py_version = "py3" , enable_network_isolation = True
1010+ )
1011+ tf ._prepare_for_training () # set output_path and job name as if training happened
1012+
1013+ model = tf .create_model ()
1014+
1015+ assert isinstance (model , serving .Model )
1016+
1017+ assert model .model_data == tf .model_data
1018+ assert model .role == tf .role
1019+ assert model .name == tf ._current_job_name
1020+ assert model .container_log_level == tf .container_log_level
1021+ assert model ._framework_version == "1.11"
1022+ assert model .sagemaker_session == sagemaker_session
1023+ assert model .enable_network_isolation ()
10101024
10111025
10121026@patch ("sagemaker.utils.create_tar_file" , MagicMock ())
0 commit comments