3939TAGS = [{"Key" : "some-key" , "Value" : "some-value" }]
4040
4141
42- def test_mnist_with_checkpoint_config (sagemaker_session , instance_type ):
42+ def test_mnist_with_checkpoint_config (sagemaker_session , instance_type , tf_full_version ):
4343 checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}" .format (
4444 sagemaker_session .default_bucket (), sagemaker_timestamp ()
4545 )
@@ -51,7 +51,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type):
5151 train_instance_type = instance_type ,
5252 sagemaker_session = sagemaker_session ,
5353 script_mode = True ,
54- framework_version = TensorFlow . LATEST_VERSION ,
54+ framework_version = tf_full_version ,
5555 py_version = tests .integ .PYTHON_VERSION ,
5656 metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
5757 checkpoint_s3_uri = checkpoint_s3_uri ,
@@ -82,7 +82,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type):
8282 assert actual_training_checkpoint_config == expected_training_checkpoint_config
8383
8484
85- def test_server_side_encryption (sagemaker_session ):
85+ def test_server_side_encryption (sagemaker_session , tf_full_version ):
8686 boto_session = sagemaker_session .boto_session
8787 with kms_utils .bucket_with_encryption (boto_session , ROLE ) as (bucket_with_kms , kms_key ):
8888 output_path = os .path .join (
@@ -97,7 +97,7 @@ def test_server_side_encryption(sagemaker_session):
9797 train_instance_type = "ml.c5.xlarge" ,
9898 sagemaker_session = sagemaker_session ,
9999 script_mode = True ,
100- framework_version = TensorFlow . LATEST_VERSION ,
100+ framework_version = tf_full_version ,
101101 py_version = tests .integ .PYTHON_VERSION ,
102102 code_location = output_path ,
103103 output_path = output_path ,
@@ -125,7 +125,7 @@ def test_server_side_encryption(sagemaker_session):
125125
126126
127127@pytest .mark .canary_quick
128- def test_mnist_distributed (sagemaker_session , instance_type ):
128+ def test_mnist_distributed (sagemaker_session , instance_type , tf_full_version ):
129129 estimator = TensorFlow (
130130 entry_point = SCRIPT ,
131131 role = ROLE ,
@@ -134,7 +134,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
134134 sagemaker_session = sagemaker_session ,
135135 py_version = tests .integ .PYTHON_VERSION ,
136136 script_mode = True ,
137- framework_version = TensorFlow . LATEST_VERSION ,
137+ framework_version = tf_full_version ,
138138 distributions = PARAMETER_SERVER_DISTRIBUTION ,
139139 )
140140 inputs = estimator .sagemaker_session .upload_data (
@@ -159,6 +159,7 @@ def test_mnist_async(sagemaker_session, cpu_instance_type):
159159 py_version = tests .integ .PYTHON_VERSION ,
160160 sagemaker_session = sagemaker_session ,
161161 script_mode = True ,
162+ # testing py-sdk functionality, no need to run against all TF versions
162163 framework_version = TensorFlow .LATEST_VERSION ,
163164 tags = TAGS ,
164165 )
@@ -191,7 +192,7 @@ def test_mnist_async(sagemaker_session, cpu_instance_type):
191192 _assert_model_name_match (sagemaker_session .sagemaker_client , endpoint_name , model_name )
192193
193194
194- def test_deploy_with_input_handlers (sagemaker_session , instance_type ):
195+ def test_deploy_with_input_handlers (sagemaker_session , instance_type , tf_full_version ):
195196 estimator = TensorFlow (
196197 entry_point = "training.py" ,
197198 source_dir = TFS_RESOURCE_PATH ,
@@ -201,7 +202,7 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type):
201202 py_version = tests .integ .PYTHON_VERSION ,
202203 sagemaker_session = sagemaker_session ,
203204 script_mode = True ,
204- framework_version = TensorFlow . LATEST_VERSION ,
205+ framework_version = tf_full_version ,
205206 tags = TAGS ,
206207 )
207208
0 commit comments