3131BUCKET_NAME = 'mybucket'
3232INSTANCE_COUNT = 1
3333INSTANCE_TYPE = 'ml.c4.4xlarge'
34- CPU_IMAGE_NAME = 'sagemaker-tensorflow-py2-cpu'
35- GPU_IMAGE_NAME = 'sagemaker-tensorflow-py2-gpu'
36- JOB_NAME = '{}-{}' .format (CPU_IMAGE_NAME , TIMESTAMP )
34+ IMAGE_REPO_NAME = 'sagemaker-tensorflow'
35+ JOB_NAME = '{}-{}' .format (IMAGE_REPO_NAME , TIMESTAMP )
3736ROLE = 'Dummy'
3837REGION = 'us-west-2'
3938DOCKER_TAG = '1.0'
@@ -53,11 +52,11 @@ def sagemaker_session():
5352
5453
5554def _get_full_cpu_image_uri (version ):
56- return IMAGE_URI_FORMAT_STRING .format (REGION , CPU_IMAGE_NAME , version , 'cpu' , 'py2' )
55+ return IMAGE_URI_FORMAT_STRING .format (REGION , IMAGE_REPO_NAME , version , 'cpu' , 'py2' )
5756
5857
5958def _get_full_gpu_image_uri (version ):
60- return IMAGE_URI_FORMAT_STRING .format (REGION , GPU_IMAGE_NAME , version , 'gpu' , 'py2' )
59+ return IMAGE_URI_FORMAT_STRING .format (REGION , IMAGE_REPO_NAME , version , 'gpu' , 'py2' )
6160
6261
6362def _create_train_job (tf_version ):
@@ -231,11 +230,11 @@ def test_tf(time, strftime, sagemaker_session, tf_version):
231230 'SAGEMAKER_REGION' : 'us-west-2' ,
232231 'SAGEMAKER_CONTAINER_LOG_LEVEL' : '20'
233232 },
234- 'Image' : create_image_uri ('us-west-2' , "tensorflow" , GPU_IMAGE_NAME , tf_version , "py2" ),
235- 'ModelDataUrl' : 's3://m/m.tar.gz' } == model .prepare_container_def (GPU_IMAGE_NAME )
233+ 'Image' : create_image_uri ('us-west-2' , "tensorflow" , INSTANCE_TYPE , tf_version , "py2" ),
234+ 'ModelDataUrl' : 's3://m/m.tar.gz' } == model .prepare_container_def (INSTANCE_TYPE )
236235
237- assert 'cpu' in model .prepare_container_def (CPU_IMAGE_NAME )['Image' ]
238- predictor = tf .deploy (1 , GPU_IMAGE_NAME )
236+ assert 'cpu' in model .prepare_container_def (INSTANCE_TYPE )['Image' ]
237+ predictor = tf .deploy (1 , INSTANCE_TYPE )
239238 assert isinstance (predictor , TensorFlowPredictor )
240239
241240
@@ -257,7 +256,7 @@ def test_run_tensorboard_locally_without_tensorboard_binary(time, strftime, pope
257256def test_model (sagemaker_session , tf_version ):
258257 model = TensorFlowModel ("s3://some/data.tar.gz" , role = ROLE , entry_point = SCRIPT_PATH ,
259258 sagemaker_session = sagemaker_session )
260- predictor = model .deploy (1 , GPU_IMAGE_NAME )
259+ predictor = model .deploy (1 , INSTANCE_TYPE )
261260 assert isinstance (predictor , TensorFlowPredictor )
262261
263262
@@ -410,6 +409,54 @@ def test_attach(sagemaker_session, tf_version):
410409 assert estimator .checkpoint_path == 's3://other/1508872349'
411410
412411
412+ def test_attach_new_repo_name (sagemaker_session , tf_version ):
413+ training_image = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:{}-cpu-py2' .format (tf_version )
414+ rjd = {'AlgorithmSpecification' :
415+ {'TrainingInputMode' : 'File' ,
416+ 'TrainingImage' : training_image },
417+ 'HyperParameters' :
418+ {'sagemaker_submit_directory' : '"s3://some/sourcedir.tar.gz"' ,
419+ 'checkpoint_path' : '"s3://other/1508872349"' ,
420+ 'sagemaker_program' : '"iris-dnn-classifier.py"' ,
421+ 'sagemaker_enable_cloudwatch_metrics' : 'false' ,
422+ 'sagemaker_container_log_level' : '"logging.INFO"' ,
423+ 'sagemaker_job_name' : '"neo"' ,
424+ 'training_steps' : '100' ,
425+ 'evaluation_steps' : '10' },
426+ 'RoleArn' : 'arn:aws:iam::366:role/SageMakerRole' ,
427+ 'ResourceConfig' :
428+ {'VolumeSizeInGB' : 30 ,
429+ 'InstanceCount' : 1 ,
430+ 'InstanceType' : 'ml.c4.xlarge' },
431+ 'StoppingCondition' : {'MaxRuntimeInSeconds' : 24 * 60 * 60 },
432+ 'TrainingJobName' : 'neo' ,
433+ 'TrainingJobStatus' : 'Completed' ,
434+ 'OutputDataConfig' : {'KmsKeyId' : '' ,
435+ 'S3OutputPath' : 's3://place/output/neo' },
436+ 'TrainingJobOutput' : {'S3TrainingJobOutput' : 's3://here/output.tar.gz' }}
437+ sagemaker_session .sagemaker_client .describe_training_job = Mock (name = 'describe_training_job' , return_value = rjd )
438+
439+ estimator = TensorFlow .attach (training_job_name = 'neo' , sagemaker_session = sagemaker_session )
440+ assert estimator .latest_training_job .job_name == 'neo'
441+ assert estimator .py_version == 'py2'
442+ assert estimator .framework_version == tf_version
443+ assert estimator .role == 'arn:aws:iam::366:role/SageMakerRole'
444+ assert estimator .train_instance_count == 1
445+ assert estimator .train_max_run == 24 * 60 * 60
446+ assert estimator .input_mode == 'File'
447+ assert estimator .training_steps == 100
448+ assert estimator .evaluation_steps == 10
449+ assert estimator .input_mode == 'File'
450+ assert estimator .base_job_name == 'neo'
451+ assert estimator .output_path == 's3://place/output/neo'
452+ assert estimator .output_kms_key == ''
453+ assert estimator .hyperparameters ()['training_steps' ] == '100'
454+ assert estimator .source_dir == 's3://some/sourcedir.tar.gz'
455+ assert estimator .entry_point == 'iris-dnn-classifier.py'
456+ assert estimator .checkpoint_path == 's3://other/1508872349'
457+ assert estimator .train_image () == training_image
458+
459+
413460def test_attach_old_container (sagemaker_session ):
414461 training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:1.0'
415462 rjd = {'AlgorithmSpecification' :
0 commit comments