2121import boto3
2222from sagemaker .tensorflow import TensorFlow
2323from six .moves .urllib .parse import urlparse
24+ from sagemaker .utils import unique_name_from_base
2425import tests .integ as integ
2526from tests .integ import kms_utils
2627import tests .integ .timeout as timeout
3132SCRIPT = os .path .join (RESOURCE_PATH , 'mnist.py' )
3233PARAMETER_SERVER_DISTRIBUTION = {'parameter_server' : {'enabled' : True }}
3334MPI_DISTRIBUTION = {'mpi' : {'enabled' : True }}
35+ TAGS = [{'Key' : 'some-key' , 'Value' : 'some-value' }]
3436
3537
3638@pytest .fixture (scope = 'session' , params = ['ml.c5.xlarge' , 'ml.p2.xlarge' ])
@@ -48,7 +50,7 @@ def test_mnist(sagemaker_session, instance_type):
4850 py_version = 'py3' ,
4951 framework_version = TensorFlow .LATEST_VERSION ,
5052 metric_definitions = [{'Name' : 'train:global_steps' , 'Regex' : r'global_step\/sec:\s(.*)' }],
51- base_job_name = 'test-tf-sm-mnist' )
53+ base_job_name = unique_name_from_base ( 'test-tf-sm-mnist' ) )
5254 inputs = estimator .sagemaker_session .upload_data (
5355 path = os .path .join (RESOURCE_PATH , 'data' ),
5456 key_prefix = 'scriptmode/mnist' )
@@ -76,7 +78,7 @@ def test_server_side_encryption(sagemaker_session):
7678 sagemaker_session = sagemaker_session ,
7779 py_version = 'py3' ,
7880 framework_version = '1.11' ,
79- base_job_name = 'test-server-side-encryption' ,
81+ base_job_name = unique_name_from_base ( 'test-server-side-encryption' ) ,
8082 code_location = output_path ,
8183 output_path = output_path ,
8284 model_dir = '/opt/ml/model' ,
@@ -103,7 +105,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
103105 script_mode = True ,
104106 framework_version = TensorFlow .LATEST_VERSION ,
105107 distributions = PARAMETER_SERVER_DISTRIBUTION ,
106- base_job_name = 'test-tf-sm-mnist' )
108+ base_job_name = unique_name_from_base ( 'test-tf-sm-mnist' ) )
107109 inputs = estimator .sagemaker_session .upload_data (
108110 path = os .path .join (RESOURCE_PATH , 'data' ),
109111 key_prefix = 'scriptmode/distributed_mnist' )
@@ -122,21 +124,25 @@ def test_mnist_async(sagemaker_session):
122124 sagemaker_session = sagemaker_session ,
123125 py_version = 'py3' ,
124126 framework_version = TensorFlow .LATEST_VERSION ,
125- base_job_name = 'test-tf-sm-mnist' )
127+ base_job_name = unique_name_from_base ('test-tf-sm-mnist' ),
128+ tags = TAGS )
126129 inputs = estimator .sagemaker_session .upload_data (
127130 path = os .path .join (RESOURCE_PATH , 'data' ),
128131 key_prefix = 'scriptmode/mnist' )
129132 estimator .fit (inputs , wait = False )
130133 training_job_name = estimator .latest_training_job .name
131134 time .sleep (20 )
132135 endpoint_name = training_job_name
136+ _assert_training_job_tags_match (sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS )
133137 with timeout .timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
134138 estimator = TensorFlow .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
135139 predictor = estimator .deploy (initial_instance_count = 1 , instance_type = 'ml.c4.xlarge' ,
136140 endpoint_name = endpoint_name )
137141
138142 result = predictor .predict (np .zeros (784 ))
139143 print ('predict result: {}' .format (result ))
144+ _assert_endpoint_tags_match (sagemaker_session .sagemaker_client , predictor .endpoint , TAGS )
145+ _assert_model_tags_match (sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS )
140146
141147
142148def _assert_s3_files_exist (s3_url , files ):
@@ -147,3 +153,23 @@ def _assert_s3_files_exist(s3_url, files):
147153 found = [x ['Key' ] for x in contents if x ['Key' ].endswith (f )]
148154 if not found :
149155 raise ValueError ('File {} is not found under {}' .format (f , s3_url ))
156+
157+
158+ def _assert_tags_match (sagemaker_client , resource_arn , tags ):
159+ actual_tags = sagemaker_client .list_tags (ResourceArn = resource_arn )['Tags' ]
160+ assert actual_tags == tags
161+
162+
163+ def _assert_model_tags_match (sagemaker_client , model_name , tags ):
164+ model_description = sagemaker_client .describe_model (ModelName = model_name )
165+ _assert_tags_match (sagemaker_client , model_description ['ModelArn' ], tags )
166+
167+
168+ def _assert_endpoint_tags_match (sagemaker_client , endpoint_name , tags ):
169+ endpoint_description = sagemaker_client .describe_endpoint (EndpointName = endpoint_name )
170+ _assert_tags_match (sagemaker_client , endpoint_description ['EndpointArn' ], tags )
171+
172+
173+ def _assert_training_job_tags_match (sagemaker_client , training_job_name , tags ):
174+ training_job_description = sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
175+ _assert_tags_match (sagemaker_client , training_job_description ['TrainingJobArn' ], tags )
0 commit comments