1717
1818import pytest
1919
20- import tests .integ
2120from sagemaker .inputs import FileSystemInput
2221from sagemaker .parameter import IntegerParameter
2322from sagemaker .tensorflow import TensorFlow
3231MNIST_RESOURCE_PATH = os .path .join (RESOURCE_PATH , "tensorflow_mnist" )
3332SCRIPT = os .path .join (MNIST_RESOURCE_PATH , "mnist.py" )
3433TFS_RESOURCE_PATH = os .path .join (RESOURCE_PATH , "tfs" , "tfs-test-entrypoint-with-handler" )
35- INSTANCE_TYPE = "ml.c4.xlarge"
3634EFS_DIR_PATH = "/tensorflow"
3735FSX_DIR_PATH = "/fsx/tensorflow"
3836MAX_JOBS = 2
@@ -49,11 +47,7 @@ def efs_fsx_setup(sagemaker_session):
4947 tear_down (sagemaker_session , fs_resources )
5048
5149
52- @pytest .mark .skipif (
53- tests .integ .test_region () not in tests .integ .EFS_TEST_ENABLED_REGION ,
54- reason = "EFS integration tests need to be fixed before running in all regions." ,
55- )
56- def test_mnist_efs (efs_fsx_setup , sagemaker_session ):
50+ def test_mnist_efs (efs_fsx_setup , sagemaker_session , cpu_instance_type ):
5751 role = efs_fsx_setup .role_name
5852 subnets = [efs_fsx_setup .subnet_id ]
5953 security_group_ids = efs_fsx_setup .security_group_ids
@@ -62,7 +56,7 @@ def test_mnist_efs(efs_fsx_setup, sagemaker_session):
6256 entry_point = SCRIPT ,
6357 role = role ,
6458 train_instance_count = 1 ,
65- train_instance_type = INSTANCE_TYPE ,
59+ train_instance_type = cpu_instance_type ,
6660 sagemaker_session = sagemaker_session ,
6761 script_mode = True ,
6862 framework_version = TensorFlow .LATEST_VERSION ,
@@ -85,11 +79,7 @@ def test_mnist_efs(efs_fsx_setup, sagemaker_session):
8579 )
8680
8781
88- @pytest .mark .skipif (
89- tests .integ .test_region () not in tests .integ .EFS_TEST_ENABLED_REGION ,
90- reason = "EFS integration tests need to be fixed before running in all regions." ,
91- )
92- def test_mnist_lustre (efs_fsx_setup , sagemaker_session ):
82+ def test_mnist_lustre (efs_fsx_setup , sagemaker_session , cpu_instance_type ):
9383 role = efs_fsx_setup .role_name
9484 subnets = [efs_fsx_setup .subnet_id ]
9585 security_group_ids = efs_fsx_setup .security_group_ids
@@ -98,7 +88,7 @@ def test_mnist_lustre(efs_fsx_setup, sagemaker_session):
9888 entry_point = SCRIPT ,
9989 role = role ,
10090 train_instance_count = 1 ,
101- train_instance_type = INSTANCE_TYPE ,
91+ train_instance_type = cpu_instance_type ,
10292 sagemaker_session = sagemaker_session ,
10393 script_mode = True ,
10494 framework_version = TensorFlow .LATEST_VERSION ,
@@ -121,11 +111,7 @@ def test_mnist_lustre(efs_fsx_setup, sagemaker_session):
121111 )
122112
123113
124- @pytest .mark .skipif (
125- tests .integ .test_region () not in tests .integ .EFS_TEST_ENABLED_REGION ,
126- reason = "EFS integration tests need to be fixed before running in all regions." ,
127- )
128- def test_tuning_tf_script_mode_efs (efs_fsx_setup , sagemaker_session ):
114+ def test_tuning_tf_script_mode_efs (efs_fsx_setup , sagemaker_session , cpu_instance_type ):
129115 role = efs_fsx_setup .role_name
130116 subnets = [efs_fsx_setup .subnet_id ]
131117 security_group_ids = efs_fsx_setup .security_group_ids
@@ -134,7 +120,7 @@ def test_tuning_tf_script_mode_efs(efs_fsx_setup, sagemaker_session):
134120 entry_point = SCRIPT ,
135121 role = role ,
136122 train_instance_count = 1 ,
137- train_instance_type = INSTANCE_TYPE ,
123+ train_instance_type = cpu_instance_type ,
138124 script_mode = True ,
139125 sagemaker_session = sagemaker_session ,
140126 py_version = PY_VERSION ,
@@ -169,11 +155,7 @@ def test_tuning_tf_script_mode_efs(efs_fsx_setup, sagemaker_session):
169155 assert best_training_job
170156
171157
172- @pytest .mark .skipif (
173- tests .integ .test_region () not in tests .integ .EFS_TEST_ENABLED_REGION ,
174- reason = "EFS integration tests need to be fixed before running in all regions." ,
175- )
176- def test_tuning_tf_script_mode_lustre (efs_fsx_setup , sagemaker_session ):
158+ def test_tuning_tf_script_mode_lustre (efs_fsx_setup , sagemaker_session , cpu_instance_type ):
177159 role = efs_fsx_setup .role_name
178160 subnets = [efs_fsx_setup .subnet_id ]
179161 security_group_ids = efs_fsx_setup .security_group_ids
@@ -182,7 +164,7 @@ def test_tuning_tf_script_mode_lustre(efs_fsx_setup, sagemaker_session):
182164 entry_point = SCRIPT ,
183165 role = role ,
184166 train_instance_count = 1 ,
185- train_instance_type = INSTANCE_TYPE ,
167+ train_instance_type = cpu_instance_type ,
186168 script_mode = True ,
187169 sagemaker_session = sagemaker_session ,
188170 py_version = PY_VERSION ,
0 commit comments