1919import pytest
2020
2121from sagemaker .tensorflow import TensorFlow
22+ from sagemaker .tensorflow .defaults import LATEST_SERVING_VERSION
2223from sagemaker .utils import unique_name_from_base , sagemaker_timestamp
2324
2425import tests .integ
4041
4142
4243@pytest .fixture (scope = "module" )
43- def py_version (tf_full_version ):
44- return (
45- "py37" if tf_full_version == TensorFlow ._LATEST_1X_VERSION else tests .integ .PYTHON_VERSION
46- )
44+ def py_version (tf_full_version , tf_serving_version ):
45+ return "py37" if tf_full_version == tf_serving_version else tests .integ .PYTHON_VERSION
4746
4847
4948def test_mnist_with_checkpoint_config (
@@ -61,7 +60,7 @@ def test_mnist_with_checkpoint_config(
6160 sagemaker_session = sagemaker_session ,
6261 script_mode = True ,
6362 framework_version = tf_full_version ,
64- py_version = py_version ,
63+ py_version = "py37" ,
6564 metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
6665 checkpoint_s3_uri = checkpoint_s3_uri ,
6766 checkpoint_local_path = checkpoint_local_path ,
@@ -91,7 +90,7 @@ def test_mnist_with_checkpoint_config(
9190 assert actual_training_checkpoint_config == expected_training_checkpoint_config
9291
9392
94- def test_server_side_encryption (sagemaker_session , tf_full_version , py_version ):
93+ def test_server_side_encryption (sagemaker_session , tf_serving_version , py_version ):
9594 with kms_utils .bucket_with_encryption (sagemaker_session , ROLE ) as (bucket_with_kms , kms_key ):
9695 output_path = os .path .join (
9796 bucket_with_kms , "test-server-side-encryption" , time .strftime ("%y%m%d-%H%M" )
@@ -105,7 +104,7 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
105104 train_instance_type = "ml.c5.xlarge" ,
106105 sagemaker_session = sagemaker_session ,
107106 script_mode = True ,
108- framework_version = tf_full_version ,
107+ framework_version = tf_serving_version ,
109108 py_version = py_version ,
110109 code_location = output_path ,
111110 output_path = output_path ,
@@ -140,7 +139,7 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
140139 train_instance_count = 2 ,
141140 train_instance_type = instance_type ,
142141 sagemaker_session = sagemaker_session ,
143- py_version = py_version ,
142+ py_version = "py37" ,
144143 script_mode = True ,
145144 framework_version = tf_full_version ,
146145 distributions = PARAMETER_SERVER_DISTRIBUTION ,
@@ -168,7 +167,7 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
168167 sagemaker_session = sagemaker_session ,
169168 script_mode = True ,
170169 # testing py-sdk functionality, no need to run against all TF versions
171- framework_version = TensorFlow . LATEST_VERSION ,
170+ framework_version = LATEST_SERVING_VERSION ,
172171 tags = TAGS ,
173172 )
174173 inputs = estimator .sagemaker_session .upload_data (
@@ -200,7 +199,9 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
200199 _assert_model_name_match (sagemaker_session .sagemaker_client , endpoint_name , model_name )
201200
202201
203- def test_deploy_with_input_handlers (sagemaker_session , instance_type , tf_full_version , py_version ):
202+ def test_deploy_with_input_handlers (
203+ sagemaker_session , instance_type , tf_serving_version , py_version
204+ ):
204205 estimator = TensorFlow (
205206 entry_point = "training.py" ,
206207 source_dir = TFS_RESOURCE_PATH ,
@@ -210,7 +211,7 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_ve
210211 py_version = py_version ,
211212 sagemaker_session = sagemaker_session ,
212213 script_mode = True ,
213- framework_version = tf_full_version ,
214+ framework_version = tf_serving_version ,
214215 tags = TAGS ,
215216 )
216217
0 commit comments