2727
2828from sagemaker .local import LocalSession , LocalSagemakerRuntimeClient , LocalSagemakerClient
2929from sagemaker .mxnet import MXNet
30- from sagemaker .tensorflow import estimator
30+ from sagemaker .tensorflow import TensorFlow
3131
3232# endpoint tests all use the same port, so we use this lock to prevent concurrent execution
3333LOCK_PATH = os .path .join (tempfile .gettempdir (), "sagemaker_test_local_mode_lock" )
@@ -90,7 +90,7 @@ def test_tf_local_mode(sagemaker_local_session):
9090 with stopit .ThreadingTimeout (5 * 60 , swallow_exc = False ):
9191 script_path = os .path .join (DATA_DIR , "iris" , "iris-dnn-classifier.py" )
9292
93- tensorflow_estimator = estimator . TensorFlow (
93+ estimator = TensorFlow (
9494 entry_point = script_path ,
9595 role = "SageMakerRole" ,
9696 framework_version = "1.12" ,
@@ -103,16 +103,16 @@ def test_tf_local_mode(sagemaker_local_session):
103103 sagemaker_session = sagemaker_local_session ,
104104 )
105105
106- inputs = tensorflow_estimator .sagemaker_session .upload_data (
106+ inputs = estimator .sagemaker_session .upload_data (
107107 path = DATA_PATH , key_prefix = "integ-test-data/tf_iris"
108108 )
109- tensorflow_estimator .fit (inputs )
110- print ("job succeeded: {}" .format (tensorflow_estimator .latest_training_job .name ))
109+ estimator .fit (inputs )
110+ print ("job succeeded: {}" .format (estimator .latest_training_job .name ))
111111
112- endpoint_name = tensorflow_estimator .latest_training_job .name
112+ endpoint_name = estimator .latest_training_job .name
113113 with lock .lock (LOCK_PATH ):
114114 try :
115- json_predictor = tensorflow_estimator .deploy (
115+ json_predictor = estimator .deploy (
116116 initial_instance_count = 1 , instance_type = "local" , endpoint_name = endpoint_name
117117 )
118118
@@ -124,7 +124,7 @@ def test_tf_local_mode(sagemaker_local_session):
124124
125125 assert dict_result == list_result
126126 finally :
127- tensorflow_estimator .delete_endpoint ()
127+ estimator .delete_endpoint ()
128128
129129
130130@pytest .mark .local_mode
@@ -133,7 +133,7 @@ def test_tf_distributed_local_mode(sagemaker_local_session):
133133 with stopit .ThreadingTimeout (5 * 60 , swallow_exc = False ):
134134 script_path = os .path .join (DATA_DIR , "iris" , "iris-dnn-classifier.py" )
135135
136- tensorflow_estimator = estimator . TensorFlow (
136+ estimator = TensorFlow (
137137 entry_point = script_path ,
138138 role = "SageMakerRole" ,
139139 framework_version = "1.12" ,
@@ -147,14 +147,14 @@ def test_tf_distributed_local_mode(sagemaker_local_session):
147147 )
148148
149149 inputs = "file://" + DATA_PATH
150- tensorflow_estimator .fit (inputs )
151- print ("job succeeded: {}" .format (tensorflow_estimator .latest_training_job .name ))
150+ estimator .fit (inputs )
151+ print ("job succeeded: {}" .format (estimator .latest_training_job .name ))
152152
153- endpoint_name = tensorflow_estimator .latest_training_job .name
153+ endpoint_name = estimator .latest_training_job .name
154154
155155 with lock .lock (LOCK_PATH ):
156156 try :
157- json_predictor = tensorflow_estimator .deploy (
157+ json_predictor = estimator .deploy (
158158 initial_instance_count = 1 , instance_type = "local" , endpoint_name = endpoint_name
159159 )
160160
@@ -166,7 +166,7 @@ def test_tf_distributed_local_mode(sagemaker_local_session):
166166
167167 assert dict_result == list_result
168168 finally :
169- tensorflow_estimator .delete_endpoint ()
169+ estimator .delete_endpoint ()
170170
171171
172172@pytest .mark .local_mode
@@ -175,7 +175,7 @@ def test_tf_local_data(sagemaker_local_session):
175175 with stopit .ThreadingTimeout (5 * 60 , swallow_exc = False ):
176176 script_path = os .path .join (DATA_DIR , "iris" , "iris-dnn-classifier.py" )
177177
178- tensorflow_estimator = estimator . TensorFlow (
178+ estimator = TensorFlow (
179179 entry_point = script_path ,
180180 role = "SageMakerRole" ,
181181 framework_version = "1.12" ,
@@ -189,13 +189,13 @@ def test_tf_local_data(sagemaker_local_session):
189189 )
190190
191191 inputs = "file://" + DATA_PATH
192- tensorflow_estimator .fit (inputs )
193- print ("job succeeded: {}" .format (tensorflow_estimator .latest_training_job .name ))
192+ estimator .fit (inputs )
193+ print ("job succeeded: {}" .format (estimator .latest_training_job .name ))
194194
195- endpoint_name = tensorflow_estimator .latest_training_job .name
195+ endpoint_name = estimator .latest_training_job .name
196196 with lock .lock (LOCK_PATH ):
197197 try :
198- json_predictor = tensorflow_estimator .deploy (
198+ json_predictor = estimator .deploy (
199199 initial_instance_count = 1 , instance_type = "local" , endpoint_name = endpoint_name
200200 )
201201
@@ -207,7 +207,7 @@ def test_tf_local_data(sagemaker_local_session):
207207
208208 assert dict_result == list_result
209209 finally :
210- tensorflow_estimator .delete_endpoint ()
210+ estimator .delete_endpoint ()
211211
212212
213213@pytest .mark .local_mode
@@ -216,7 +216,7 @@ def test_tf_local_data_local_script():
216216 with stopit .ThreadingTimeout (5 * 60 , swallow_exc = False ):
217217 script_path = os .path .join (DATA_DIR , "iris" , "iris-dnn-classifier.py" )
218218
219- tensorflow_estimator = estimator . TensorFlow (
219+ estimator = TensorFlow (
220220 entry_point = script_path ,
221221 role = "SageMakerRole" ,
222222 framework_version = "1.12" ,
@@ -231,13 +231,13 @@ def test_tf_local_data_local_script():
231231
232232 inputs = "file://" + DATA_PATH
233233
234- tensorflow_estimator .fit (inputs )
235- print ("job succeeded: {}" .format (tensorflow_estimator .latest_training_job .name ))
234+ estimator .fit (inputs )
235+ print ("job succeeded: {}" .format (estimator .latest_training_job .name ))
236236
237- endpoint_name = tensorflow_estimator .latest_training_job .name
237+ endpoint_name = estimator .latest_training_job .name
238238 with lock .lock (LOCK_PATH ):
239239 try :
240- json_predictor = tensorflow_estimator .deploy (
240+ json_predictor = estimator .deploy (
241241 initial_instance_count = 1 , instance_type = "local" , endpoint_name = endpoint_name
242242 )
243243
@@ -249,7 +249,7 @@ def test_tf_local_data_local_script():
249249
250250 assert dict_result == list_result
251251 finally :
252- tensorflow_estimator .delete_endpoint ()
252+ estimator .delete_endpoint ()
253253
254254
255255@pytest .mark .local_mode
0 commit comments