2323from sagemaker .utils import unique_name_from_base , sagemaker_timestamp
2424
2525import tests .integ
26- from tests .integ import timeout
27- from tests .integ import kms_utils
26+ from tests .integ import kms_utils , timeout , PYTHON_VERSION
2827from tests .integ .retry import retries
2928from tests .integ .s3_utils import assert_s3_files_exist
3029
4039TAGS = [{"Key" : "some-key" , "Value" : "some-value" }]
4140
4241
43- @pytest .fixture (scope = "module" )
44- def py_version (tf_full_version , tf_serving_version ):
45- return "py37" if tf_full_version == tf_serving_version else tests .integ .PYTHON_VERSION
46-
47-
4842def test_mnist_with_checkpoint_config (
49- sagemaker_session , instance_type , tf_full_version , py_version
43+ sagemaker_session , instance_type , tf_full_version , tf_full_py_version
5044):
5145 checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}" .format (
5246 sagemaker_session .default_bucket (), sagemaker_timestamp ()
@@ -59,7 +53,7 @@ def test_mnist_with_checkpoint_config(
5953 train_instance_type = instance_type ,
6054 sagemaker_session = sagemaker_session ,
6155 framework_version = tf_full_version ,
62- py_version = "py37" ,
56+ py_version = tf_full_py_version ,
6357 metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
6458 checkpoint_s3_uri = checkpoint_s3_uri ,
6559 checkpoint_local_path = checkpoint_local_path ,
@@ -89,7 +83,7 @@ def test_mnist_with_checkpoint_config(
8983 assert actual_training_checkpoint_config == expected_training_checkpoint_config
9084
9185
92- def test_server_side_encryption (sagemaker_session , tf_serving_version , py_version ):
86+ def test_server_side_encryption (sagemaker_session , tf_serving_version ):
9387 with kms_utils .bucket_with_encryption (sagemaker_session , ROLE ) as (bucket_with_kms , kms_key ):
9488 output_path = os .path .join (
9589 bucket_with_kms , "test-server-side-encryption" , time .strftime ("%y%m%d-%H%M" )
@@ -103,7 +97,7 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio
10397 train_instance_type = "ml.c5.xlarge" ,
10498 sagemaker_session = sagemaker_session ,
10599 framework_version = tf_serving_version ,
106- py_version = py_version ,
100+ py_version = PYTHON_VERSION ,
107101 code_location = output_path ,
108102 output_path = output_path ,
109103 model_dir = "/opt/ml/model" ,
@@ -130,15 +124,15 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio
130124
131125
132126@pytest .mark .canary_quick
133- def test_mnist_distributed (sagemaker_session , instance_type , tf_full_version , py_version ):
127+ def test_mnist_distributed (sagemaker_session , instance_type , tf_full_version , tf_full_py_version ):
134128 estimator = TensorFlow (
135129 entry_point = SCRIPT ,
136130 role = ROLE ,
137131 train_instance_count = 2 ,
138132 train_instance_type = instance_type ,
139133 sagemaker_session = sagemaker_session ,
140- py_version = "py37" ,
141134 framework_version = tf_full_version ,
135+ py_version = tf_full_py_version ,
142136 distributions = PARAMETER_SERVER_DISTRIBUTION ,
143137 )
144138 inputs = estimator .sagemaker_session .upload_data (
@@ -154,13 +148,13 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
154148 )
155149
156150
157- def test_mnist_async (sagemaker_session , cpu_instance_type , tf_full_version , py_version ):
151+ def test_mnist_async (sagemaker_session , cpu_instance_type ):
158152 estimator = TensorFlow (
159153 entry_point = SCRIPT ,
160154 role = ROLE ,
161155 train_instance_count = 1 ,
162156 train_instance_type = "ml.c5.4xlarge" ,
163- py_version = tests . integ . PYTHON_VERSION ,
157+ py_version = PYTHON_VERSION ,
164158 sagemaker_session = sagemaker_session ,
165159 # testing py-sdk functionality, no need to run against all TF versions
166160 framework_version = LATEST_SERVING_VERSION ,
@@ -195,18 +189,16 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
195189 _assert_model_name_match (sagemaker_session .sagemaker_client , endpoint_name , model_name )
196190
197191
198- def test_deploy_with_input_handlers (
199- sagemaker_session , instance_type , tf_serving_version , py_version
200- ):
192+ def test_deploy_with_input_handlers (sagemaker_session , instance_type , tf_serving_version ):
201193 estimator = TensorFlow (
202194 entry_point = "training.py" ,
203195 source_dir = TFS_RESOURCE_PATH ,
204196 role = ROLE ,
205197 train_instance_count = 1 ,
206198 train_instance_type = instance_type ,
207- py_version = py_version ,
208- sagemaker_session = sagemaker_session ,
209199 framework_version = tf_serving_version ,
200+ py_version = PYTHON_VERSION ,
201+ sagemaker_session = sagemaker_session ,
210202 tags = TAGS ,
211203 )
212204
0 commit comments