1313from __future__ import absolute_import
1414
1515import pytest
16- from mock import Mock
16+ from mock import Mock , ANY
1717from sagemaker .tensorflow import TensorFlow
1818
1919
2424INSTANCE_COUNT = 1
2525INSTANCE_TYPE_GPU = "ml.p2.xlarge"
2626INSTANCE_TYPE_CPU = "ml.m4.xlarge"
27- CPU_IMAGE_NAME = "sagemaker- tensorflow-py2-cpu "
28- GPU_IMAGE_NAME = "sagemaker-tensorflow-py2-gpu "
27+ REPOSITORY = "tensorflow-inference "
28+ PROCESSOR = "cpu "
2929REGION = "us-west-2"
30- IMAGE_URI_FORMAT_STRING = "520713654638 .dkr.ecr.{}.amazonaws.com/{}:{}- {}-{}"
30+ IMAGE_URI_FORMAT_STRING = "763104351884 .dkr.ecr.{}.amazonaws.com/{}:{}-{}"
3131REGION = "us-west-2"
3232ROLE = "SagemakerRole"
3333SOURCE_DIR = "s3://fefergerger"
34+ ENDPOINT_DESC = {"EndpointConfigName" : "test-endpoint" }
35+ ENDPOINT_CONFIG_DESC = {"ProductionVariants" : [{"ModelName" : "model-1" }, {"ModelName" : "model-2" }]}
3436
3537
3638@pytest .fixture ()
@@ -39,48 +41,64 @@ def sagemaker_session():
3941 ims = Mock (
4042 name = "sagemaker_session" ,
4143 boto_session = boto_mock ,
44+ boto_region_name = REGION ,
4245 config = None ,
4346 local_mode = False ,
44- region_name = REGION ,
47+ s3_resource = None ,
48+ s3_client = None ,
4549 )
4650 ims .default_bucket = Mock (name = "default_bucket" , return_value = BUCKET_NAME )
4751 ims .expand_role = Mock (name = "expand_role" , return_value = ROLE )
4852 ims .sagemaker_client .describe_training_job = Mock (
4953 return_value = {"ModelArtifacts" : {"S3ModelArtifacts" : "s3://m/m.tar.gz" }}
5054 )
55+ ims .sagemaker_client .describe_endpoint = Mock (return_value = ENDPOINT_DESC )
56+ ims .sagemaker_client .describe_endpoint_config = Mock (return_value = ENDPOINT_CONFIG_DESC )
5157 return ims
5258
5359
60+ def test_model_dir_false (sagemaker_session ):
61+ estimator = TensorFlow (
62+ entry_point = SCRIPT ,
63+ source_dir = SOURCE_DIR ,
64+ role = ROLE ,
65+ framework_version = "2.3.0" ,
66+ py_version = "py37" ,
67+ instance_type = "ml.m4.xlarge" ,
68+ instance_count = 1 ,
69+ model_dir = False ,
70+ )
71+ estimator .hyperparameters ()
72+ assert estimator .model_dir is False
73+
74+
5475# Test that we pass all necessary fields from estimator to the session when we call deploy
55- def test_deploy (sagemaker_session , tf_version ):
76+ def test_deploy (sagemaker_session ):
5677 estimator = TensorFlow (
5778 entry_point = SCRIPT ,
5879 source_dir = SOURCE_DIR ,
5980 role = ROLE ,
60- framework_version = tf_version ,
61- train_instance_count = 2 ,
62- train_instance_type = INSTANCE_TYPE_CPU ,
81+ framework_version = "2.3.0" ,
82+ py_version = "py37" ,
83+ instance_count = 2 ,
84+ instance_type = INSTANCE_TYPE_CPU ,
6385 sagemaker_session = sagemaker_session ,
6486 base_job_name = "test-cifar" ,
6587 )
6688
6789 estimator .fit ("s3://mybucket/train" )
68- print ("job succeeded: {}" .format (estimator .latest_training_job .name ))
6990
7091 estimator .deploy (initial_instance_count = 1 , instance_type = INSTANCE_TYPE_CPU )
71- image = IMAGE_URI_FORMAT_STRING .format (REGION , CPU_IMAGE_NAME , tf_version , "cpu " , "py2" )
92+ image = IMAGE_URI_FORMAT_STRING .format (REGION , REPOSITORY , "2.3.0 " , PROCESSOR )
7293 sagemaker_session .create_model .assert_called_with (
73- estimator . _current_job_name ,
94+ ANY ,
7495 ROLE ,
7596 {
76- "Environment" : {
77- "SAGEMAKER_CONTAINER_LOG_LEVEL" : "20" ,
78- "SAGEMAKER_SUBMIT_DIRECTORY" : SOURCE_DIR ,
79- "SAGEMAKER_REQUIREMENTS" : "" ,
80- "SAGEMAKER_REGION" : REGION ,
81- "SAGEMAKER_PROGRAM" : SCRIPT ,
82- },
8397 "Image" : image ,
98+ "Environment" : {"SAGEMAKER_TFS_NGINX_LOGLEVEL" : "info" },
8499 "ModelDataUrl" : "s3://m/m.tar.gz" ,
85100 },
101+ vpc_config = None ,
102+ enable_network_isolation = False ,
103+ tags = None ,
86104 )
0 commit comments