2222from sagemaker .tensorflow import TensorFlow
2323from six .moves .urllib .parse import urlparse
2424from sagemaker .utils import unique_name_from_base
25- import tests .integ as integ
26- from tests .integ import kms_utils
27- import tests .integ .timeout as timeout
25+
26+ import tests .integ
2827
2928ROLE = 'SageMakerRole'
3029
3534TAGS = [{'Key' : 'some-key' , 'Value' : 'some-value' }]
3635
3736
38- @pytest .fixture (scope = 'session' , params = ['ml.c5.xlarge' , 'ml.p2.xlarge' ])
37+ @pytest .fixture (scope = 'session' , params = [
38+ 'ml.c5.xlarge' ,
39+ pytest .param ('ml.p2.xlarge' ,
40+ marks = pytest .mark .skipif (
41+ tests .integ .test_region () in tests .integ .HOSTING_NO_P2_REGIONS ,
42+ reason = 'no ml.p2 instances in this region' ))])
3943def instance_type (request ):
4044 return request .param
4145
4246
43- @pytest .mark .skipif (integ .test_region () in integ .HOSTING_NO_P2_REGIONS ,
44- reason = 'no ml.p2 instances in these regions' )
45- @pytest .mark .skipif (integ .PYTHON_VERSION != 'py3' , reason = "Script Mode tests are only configured to run with Python 3" )
47+ @pytest .mark .skipif (tests .integ .PYTHON_VERSION != 'py3' ,
48+ reason = "Script Mode tests are only configured to run with Python 3" )
4649def test_mnist (sagemaker_session , instance_type ):
4750 estimator = TensorFlow (entry_point = SCRIPT ,
4851 role = 'SageMakerRole' ,
@@ -51,26 +54,26 @@ def test_mnist(sagemaker_session, instance_type):
5154 sagemaker_session = sagemaker_session ,
5255 py_version = 'py3' ,
5356 framework_version = TensorFlow .LATEST_VERSION ,
54- metric_definitions = [{'Name' : 'train:global_steps' , 'Regex' : r'global_step\/sec:\s(.*)' }])
57+ metric_definitions = [
58+ {'Name' : 'train:global_steps' , 'Regex' : r'global_step\/sec:\s(.*)' }])
5559 inputs = estimator .sagemaker_session .upload_data (
5660 path = os .path .join (RESOURCE_PATH , 'data' ),
5761 key_prefix = 'scriptmode/mnist' )
5862
59- with timeout .timeout (minutes = integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
63+ with tests . integ . timeout .timeout (minutes = tests . integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
6064 estimator .fit (inputs = inputs , job_name = unique_name_from_base ('test-tf-sm-mnist' ))
6165 _assert_s3_files_exist (estimator .model_dir ,
6266 ['graph.pbtxt' , 'model.ckpt-0.index' , 'model.ckpt-0.meta' ])
6367 df = estimator .training_job_analytics .dataframe ()
64- print (df )
6568 assert df .size > 0
6669
6770
6871def test_server_side_encryption (sagemaker_session ):
69-
7072 boto_session = sagemaker_session .boto_session
71- with kms_utils .bucket_with_encryption (boto_session , ROLE ) as (bucket_with_kms , kms_key ):
72-
73- output_path = os .path .join (bucket_with_kms , 'test-server-side-encryption' , time .strftime ('%y%m%d-%H%M' ))
73+ with tests .integ .kms_utils .bucket_with_encryption (boto_session , ROLE ) as (
74+ bucket_with_kms , kms_key ):
75+ output_path = os .path .join (bucket_with_kms , 'test-server-side-encryption' ,
76+ time .strftime ('%y%m%d-%H%M' ))
7477
7578 estimator = TensorFlow (entry_point = SCRIPT ,
7679 role = ROLE ,
@@ -88,28 +91,29 @@ def test_server_side_encryption(sagemaker_session):
8891 path = os .path .join (RESOURCE_PATH , 'data' ),
8992 key_prefix = 'scriptmode/mnist' )
9093
91- with timeout .timeout (minutes = integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
92- estimator .fit (inputs = inputs , job_name = unique_name_from_base ('test-server-side-encryption' ))
94+ with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
95+ estimator .fit (inputs = inputs ,
96+ job_name = unique_name_from_base ('test-server-side-encryption' ))
9397
9498
9599@pytest .mark .canary_quick
96- @pytest .mark .skipif (integ .PYTHON_VERSION != 'py3' , reason = "Script Mode tests are only configured to run with Python 3" )
100+ @pytest .mark .skipif (tests .integ .PYTHON_VERSION != 'py3' ,
101+ reason = "Script Mode tests are only configured to run with Python 3" )
97102def test_mnist_distributed (sagemaker_session , instance_type ):
98103 estimator = TensorFlow (entry_point = SCRIPT ,
99104 role = ROLE ,
100105 train_instance_count = 2 ,
101- # TODO: change train_instance_type to instance_type once the test is passing consistently
102- train_instance_type = 'ml.c5.xlarge' ,
106+ train_instance_type = instance_type ,
103107 sagemaker_session = sagemaker_session ,
104- py_version = integ .PYTHON_VERSION ,
108+ py_version = tests . integ .PYTHON_VERSION ,
105109 script_mode = True ,
106110 framework_version = TensorFlow .LATEST_VERSION ,
107111 distributions = PARAMETER_SERVER_DISTRIBUTION )
108112 inputs = estimator .sagemaker_session .upload_data (
109113 path = os .path .join (RESOURCE_PATH , 'data' ),
110114 key_prefix = 'scriptmode/distributed_mnist' )
111115
112- with timeout .timeout (minutes = integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
116+ with tests . integ . timeout .timeout (minutes = tests . integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
113117 estimator .fit (inputs = inputs , job_name = unique_name_from_base ('test-tf-sm-distributed' ))
114118 _assert_s3_files_exist (estimator .model_dir ,
115119 ['graph.pbtxt' , 'model.ckpt-0.index' , 'model.ckpt-0.meta' ])
@@ -131,22 +135,26 @@ def test_mnist_async(sagemaker_session):
131135 training_job_name = estimator .latest_training_job .name
132136 time .sleep (20 )
133137 endpoint_name = training_job_name
134- _assert_training_job_tags_match (sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS )
135- with timeout .timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
136- estimator = TensorFlow .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
138+ _assert_training_job_tags_match (sagemaker_session .sagemaker_client ,
139+ estimator .latest_training_job .name , TAGS )
140+ with tests .integ .timeout .timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
141+ estimator = TensorFlow .attach (training_job_name = training_job_name ,
142+ sagemaker_session = sagemaker_session )
137143 predictor = estimator .deploy (initial_instance_count = 1 , instance_type = 'ml.c4.xlarge' ,
138144 endpoint_name = endpoint_name )
139145
140146 result = predictor .predict (np .zeros (784 ))
141147 print ('predict result: {}' .format (result ))
142148 _assert_endpoint_tags_match (sagemaker_session .sagemaker_client , predictor .endpoint , TAGS )
143- _assert_model_tags_match (sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS )
149+ _assert_model_tags_match (sagemaker_session .sagemaker_client ,
150+ estimator .latest_training_job .name , TAGS )
144151
145152
146153def _assert_s3_files_exist (s3_url , files ):
147154 parsed_url = urlparse (s3_url )
148155 s3 = boto3 .client ('s3' )
149- contents = s3 .list_objects_v2 (Bucket = parsed_url .netloc , Prefix = parsed_url .path .lstrip ('/' ))["Contents" ]
156+ contents = s3 .list_objects_v2 (Bucket = parsed_url .netloc , Prefix = parsed_url .path .lstrip ('/' ))[
157+ "Contents" ]
150158 for f in files :
151159 found = [x ['Key' ] for x in contents if x ['Key' ].endswith (f )]
152160 if not found :
@@ -169,5 +177,6 @@ def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags):
169177
170178
171179def _assert_training_job_tags_match (sagemaker_client , training_job_name , tags ):
172- training_job_description = sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
180+ training_job_description = sagemaker_client .describe_training_job (
181+ TrainingJobName = training_job_name )
173182 _assert_tags_match (sagemaker_client , training_job_description ['TrainingJobArn' ], tags )
0 commit comments