1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
1414
15+ import numpy as np
1516import os
1617import pytest
18+ import time
1719
1820import boto3
1921from sagemaker .tensorflow import TensorFlow
@@ -40,7 +42,7 @@ def test_mnist(sagemaker_session, instance_type):
4042 train_instance_type = instance_type ,
4143 sagemaker_session = sagemaker_session ,
4244 py_version = 'py3' ,
43- framework_version = '1.11' ,
45+ framework_version = TensorFlow . LATEST_VERSION ,
4446 base_job_name = 'test-tf-sm-mnist' )
4547 inputs = estimator .sagemaker_session .upload_data (
4648 path = os .path .join (RESOURCE_PATH , 'data' ),
@@ -49,7 +51,7 @@ def test_mnist(sagemaker_session, instance_type):
4951 with timeout .timeout (minutes = integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
5052 estimator .fit (inputs )
5153 _assert_s3_files_exist (estimator .model_dir ,
52- ['graph.pbtxt' , 'model.ckpt-0.index' , 'model.ckpt-0.meta' , 'saved_model.pb' ])
54+ ['graph.pbtxt' , 'model.ckpt-0.index' , 'model.ckpt-0.meta' ])
5355
5456
5557@pytest .mark .canary_quick
@@ -63,7 +65,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
6365 sagemaker_session = sagemaker_session ,
6466 py_version = integ .PYTHON_VERSION ,
6567 script_mode = True ,
66- framework_version = '1.11' ,
68+ framework_version = TensorFlow . LATEST_VERSION ,
6769 distributions = PARAMETER_SERVER_DISTRIBUTION ,
6870 base_job_name = 'test-tf-sm-mnist' )
6971 inputs = estimator .sagemaker_session .upload_data (
@@ -73,7 +75,32 @@ def test_mnist_distributed(sagemaker_session, instance_type):
7375 with timeout .timeout (minutes = integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
7476 estimator .fit (inputs )
7577 _assert_s3_files_exist (estimator .model_dir ,
76- ['graph.pbtxt' , 'model.ckpt-0.index' , 'model.ckpt-0.meta' , 'saved_model.pb' ])
78+ ['graph.pbtxt' , 'model.ckpt-0.index' , 'model.ckpt-0.meta' ])
79+
80+
81+ def test_mnist_async (sagemaker_session ):
82+ estimator = TensorFlow (entry_point = SCRIPT ,
83+ role = 'SageMakerRole' ,
84+ train_instance_count = 1 ,
85+ train_instance_type = 'ml.c5.4xlarge' ,
86+ sagemaker_session = sagemaker_session ,
87+ py_version = 'py3' ,
88+ framework_version = TensorFlow .LATEST_VERSION ,
89+ base_job_name = 'test-tf-sm-mnist' )
90+ inputs = estimator .sagemaker_session .upload_data (
91+ path = os .path .join (RESOURCE_PATH , 'data' ),
92+ key_prefix = 'scriptmode/mnist' )
93+ estimator .fit (inputs , wait = False )
94+ training_job_name = estimator .latest_training_job .name
95+ time .sleep (20 )
96+ endpoint_name = training_job_name
97+ with timeout .timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
98+ estimator = TensorFlow .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
99+ predictor = estimator .deploy (initial_instance_count = 1 , instance_type = 'ml.c4.xlarge' ,
100+ endpoint_name = endpoint_name )
101+
102+ result = predictor .predict (np .zeros (784 ))
103+ print ('predict result: {}' .format (result ))
77104
78105
79106def _assert_s3_files_exist (s3_url , files ):
0 commit comments