@@ -182,21 +182,51 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
182182
183183
184184@pytest .mark .release
185- def test_mnist_distributed (
185+ def test_mnist_distributed_cpu (
186186 sagemaker_session ,
187- instance_type ,
187+ cpu_instance_type ,
188188 tensorflow_training_latest_version ,
189189 tensorflow_training_latest_py_version ,
190190):
191+ _create_and_fit_estimator (
192+ sagemaker_session ,
193+ tensorflow_training_latest_version ,
194+ tensorflow_training_latest_py_version ,
195+ cpu_instance_type ,
196+ )
197+
198+
199+ @pytest .mark .release
200+ @pytest .mark .skipif (
201+ tests .integ .test_region () in tests .integ .TRAINING_NO_P2_REGIONS
202+ and tests .integ .test_region () in tests .integ .TRAINING_NO_P3_REGIONS ,
203+ reason = "no ml.p2 or ml.p3 instances in this region" ,
204+ )
205+ @retry_with_instance_list (gpu_list (tests .integ .test_region ()))
206+ def test_mnist_distributed_gpu (
207+ sagemaker_session ,
208+ tensorflow_training_latest_version ,
209+ tensorflow_training_latest_py_version ,
210+ ** kwargs ,
211+ ):
212+ _create_and_fit_estimator (
213+ sagemaker_session ,
214+ tensorflow_training_latest_version ,
215+ tensorflow_training_latest_py_version ,
216+ kwargs ["instance_type" ],
217+ )
218+
219+
220+ def _create_and_fit_estimator (sagemaker_session , tf_version , py_version , instance_type ):
191221 estimator = TensorFlow (
192222 entry_point = SCRIPT ,
193223 source_dir = MNIST_RESOURCE_PATH ,
194224 role = ROLE ,
195225 instance_count = 2 ,
196226 instance_type = instance_type ,
197227 sagemaker_session = sagemaker_session ,
198- framework_version = tensorflow_training_latest_version ,
199- py_version = tensorflow_training_latest_py_version ,
228+ framework_version = tf_version ,
229+ py_version = py_version ,
200230 distribution = PARAMETER_SERVER_DISTRIBUTION ,
201231 disable_profiler = True ,
202232 )
0 commit comments