1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- import logging
14-
1513import json
14+ import logging
1615import os
16+
1717import pytest
1818from mock import Mock , patch
19+
20+ from sagemaker .fw_utils import create_image_uri
1921from sagemaker .model import MODEL_SERVER_WORKERS_PARAM_NAME
2022from sagemaker .session import s3_input
21- from sagemaker .tensorflow import TensorFlow
22- from sagemaker .tensorflow import defaults
23- from sagemaker .fw_utils import create_image_uri
24- from sagemaker .tensorflow import TensorFlowPredictor , TensorFlowModel
23+ from sagemaker .tensorflow import defaults , TensorFlow , TensorFlowPredictor , TensorFlowModel
2524
2625DATA_DIR = os .path .join (os .path .dirname (__file__ ), '..' , 'data' )
27- SCRIPT_PATH = os .path .join (DATA_DIR , 'dummy_script.py' )
26+ SCRIPT_FILE = 'dummy_script.py'
27+ SCRIPT_PATH = os .path .join (DATA_DIR , SCRIPT_FILE )
28+ REQUIREMENTS_FILE = 'dummy_requirements.txt'
2829TIMESTAMP = '2017-11-06-14:14:15.673'
2930TIME = 1510006209.073025
3031BUCKET_NAME = 'mybucket'
@@ -85,6 +86,7 @@ def _create_train_job(tf_version):
8586 'training_steps' : '1000' ,
8687 'evaluation_steps' : '10' ,
8788 'sagemaker_program' : json .dumps ('dummy_script.py' ),
89+ 'sagemaker_requirements' : '"{}"' .format (REQUIREMENTS_FILE ),
8890 'sagemaker_submit_directory' : json .dumps ('s3://{}/{}/source/sourcedir.tar.gz' .format (
8991 BUCKET_NAME , JOB_NAME )),
9092 'sagemaker_enable_cloudwatch_metrics' : 'false' ,
@@ -100,10 +102,10 @@ def _create_train_job(tf_version):
100102
101103def _build_tf (sagemaker_session , framework_version = defaults .TF_VERSION , train_instance_type = None ,
102104 checkpoint_path = None , enable_cloudwatch_metrics = False , base_job_name = None ,
103- training_steps = None , evalutation_steps = None , ** kwargs ):
105+ training_steps = None , evaluation_steps = None , ** kwargs ):
104106 return TensorFlow (entry_point = SCRIPT_PATH ,
105107 training_steps = training_steps ,
106- evaluation_steps = evalutation_steps ,
108+ evaluation_steps = evaluation_steps ,
107109 framework_version = framework_version ,
108110 role = ROLE ,
109111 sagemaker_session = sagemaker_session ,
@@ -158,6 +160,20 @@ def test_tf_deploy_model_server_workers_unset(sagemaker_session):
158160 assert MODEL_SERVER_WORKERS_PARAM_NAME .upper () not in sagemaker_session .method_calls [3 ][1 ][2 ]['Environment' ]
159161
160162
163+ def test_tf_invalid_requirements_path (sagemaker_session ):
164+ requirements_file = '/foo/bar/requirements.txt'
165+ with pytest .raises (ValueError ) as e :
166+ _build_tf (sagemaker_session , requirements_file = requirements_file , source_dir = DATA_DIR )
167+ assert 'Requirements file {} is not a path relative to source_dir.' .format (requirements_file ) in str (e .value )
168+
169+
170+ def test_tf_nonexistent_requirements_path (sagemaker_session ):
171+ requirements_file = 'nonexistent_requirements.txt'
172+ with pytest .raises (ValueError ) as e :
173+ _build_tf (sagemaker_session , requirements_file = requirements_file , source_dir = DATA_DIR )
174+ assert 'Requirements file {} does not exist.' .format (requirements_file ) in str (e .value )
175+
176+
161177def test_create_model (sagemaker_session , tf_version ):
162178 container_log_level = '"logging.INFO"'
163179 source_dir = 's3://mybucket/source'
@@ -186,9 +202,9 @@ def test_create_model(sagemaker_session, tf_version):
186202@patch ('time.strftime' , return_value = TIMESTAMP )
187203@patch ('time.time' , return_value = TIME )
188204def test_tf (time , strftime , sagemaker_session , tf_version ):
189- tf = TensorFlow (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
190- training_steps = 1000 , evaluation_steps = 10 , train_instance_count = INSTANCE_COUNT ,
191- train_instance_type = INSTANCE_TYPE , framework_version = tf_version )
205+ tf = TensorFlow (entry_point = SCRIPT_FILE , role = ROLE , sagemaker_session = sagemaker_session , training_steps = 1000 ,
206+ evaluation_steps = 10 , train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
207+ framework_version = tf_version , requirements_file = REQUIREMENTS_FILE , source_dir = DATA_DIR )
192208
193209 inputs = 's3://mybucket/train'
194210
@@ -210,6 +226,7 @@ def test_tf(time, strftime, sagemaker_session, tf_version):
210226 assert {'Environment' :
211227 {'SAGEMAKER_SUBMIT_DIRECTORY' : 's3://{}/{}/sourcedir.tar.gz' .format (BUCKET_NAME , JOB_NAME ),
212228 'SAGEMAKER_PROGRAM' : 'dummy_script.py' ,
229+ 'SAGEMAKER_REQUIREMENTS' : 'dummy_requirements.txt' ,
213230 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS' : 'false' ,
214231 'SAGEMAKER_REGION' : 'us-west-2' ,
215232 'SAGEMAKER_CONTAINER_LOG_LEVEL' : '20'
@@ -315,7 +332,7 @@ def test_tf_training_and_evaluation_steps_not_set(sagemaker_session):
315332 job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
316333 output_path = "s3://{}/output/{}/" .format (sagemaker_session .default_bucket (), job_name )
317334
318- tf = _build_tf (sagemaker_session , training_steps = None , evalutation_steps = None , output_path = output_path )
335+ tf = _build_tf (sagemaker_session , training_steps = None , evaluation_steps = None , output_path = output_path )
319336 tf .fit (inputs = s3_input ('s3://mybucket/train' ))
320337 assert tf .hyperparameters ()['training_steps' ] == 'null'
321338 assert tf .hyperparameters ()['evaluation_steps' ] == 'null'
@@ -325,7 +342,7 @@ def test_tf_training_and_evaluation_steps(sagemaker_session):
325342 job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
326343 output_path = "s3://{}/output/{}/" .format (sagemaker_session .default_bucket (), job_name )
327344
328- tf = _build_tf (sagemaker_session , training_steps = 123 , evalutation_steps = 456 , output_path = output_path )
345+ tf = _build_tf (sagemaker_session , training_steps = 123 , evaluation_steps = 456 , output_path = output_path )
329346 tf .fit (inputs = s3_input ('s3://mybucket/train' ))
330347 assert tf .hyperparameters ()['training_steps' ] == '123'
331348 assert tf .hyperparameters ()['evaluation_steps' ] == '456'
0 commit comments