@@ -96,10 +96,6 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
9696 assert "Could not find model" in str (exception .value )
9797
9898
99- @pytest .mark .skip (
100- reason = "This test has always failed, but the failure was masked by a bug. "
101- "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
102- )
10399def test_deploy_model_with_tags_and_kms (mxnet_training_job , sagemaker_session , mxnet_full_version ):
104100 endpoint_name = "test-mxnet-deploy-model-{}" .format (sagemaker_timestamp ())
105101
@@ -123,18 +119,20 @@ def test_deploy_model_with_tags_and_kms(mxnet_training_job, sagemaker_session, m
123119
124120 model .deploy (1 , "ml.m4.xlarge" , endpoint_name = endpoint_name , tags = tags , kms_key = kms_key_arn )
125121
126- returned_model = sagemaker_session .describe_model (EndpointName = model .name )
127- returned_model_tags = sagemaker_session .list_tags (ResourceArn = returned_model [ "ModelArn" ])[
128- "Tags"
129- ]
122+ returned_model = sagemaker_session .sagemaker_client . describe_model (ModelName = model .name )
123+ returned_model_tags = sagemaker_session .sagemaker_client . list_tags (
124+ ResourceArn = returned_model [ "ModelArn" ]
125+ )[ "Tags" ]
130126
131- endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )
132- endpoint_tags = sagemaker_session .list_tags (ResourceArn = endpoint ["EndpointArn" ])["Tags" ]
127+ endpoint = sagemaker_session .sagemaker_client .describe_endpoint (EndpointName = endpoint_name )
128+ endpoint_tags = sagemaker_session .sagemaker_client .list_tags (
129+ ResourceArn = endpoint ["EndpointArn" ]
130+ )["Tags" ]
133131
134- endpoint_config = sagemaker_session .describe_endpoint_config (
132+ endpoint_config = sagemaker_session .sagemaker_client . describe_endpoint_config (
135133 EndpointConfigName = endpoint ["EndpointConfigName" ]
136134 )
137- endpoint_config_tags = sagemaker_session .list_tags (
135+ endpoint_config_tags = sagemaker_session .sagemaker_client . list_tags (
138136 ResourceArn = endpoint_config ["EndpointConfigArn" ]
139137 )["Tags" ]
140138
@@ -148,10 +146,6 @@ def test_deploy_model_with_tags_and_kms(mxnet_training_job, sagemaker_session, m
148146 assert endpoint_config ["KmsKeyId" ] == kms_key_arn
149147
150148
151- @pytest .mark .skip (
152- reason = "This test has always failed, but the failure was masked by a bug. "
153- "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
154- )
155149def test_deploy_model_with_update_endpoint (
156150 mxnet_training_job , sagemaker_session , mxnet_full_version
157151):
@@ -172,26 +166,37 @@ def test_deploy_model_with_update_endpoint(
172166 framework_version = mxnet_full_version ,
173167 )
174168 model .deploy (1 , "ml.t2.medium" , endpoint_name = endpoint_name )
175- old_endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )
169+ old_endpoint = sagemaker_session .sagemaker_client .describe_endpoint (
170+ EndpointName = endpoint_name
171+ )
176172 old_config_name = old_endpoint ["EndpointConfigName" ]
177173
178174 model .deploy (1 , "ml.m4.xlarge" , update_endpoint = True , endpoint_name = endpoint_name )
179- new_endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )[
180- "ProductionVariants"
181- ]
182- new_production_variants = new_endpoint ["ProductionVariants" ]
175+
176+ # Wait for endpoint to finish updating
177+ max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
178+ current_retry_count = 0
179+ while current_retry_count <= max_retry_count :
180+ if current_retry_count >= max_retry_count :
181+ raise Exception ("Endpoint status not 'InService' within expected timeout." )
182+ time .sleep (30 )
183+ new_endpoint = sagemaker_session .sagemaker_client .describe_endpoint (
184+ EndpointName = endpoint_name
185+ )
186+ current_retry_count += 1
187+ if new_endpoint ["EndpointStatus" ] == "InService" :
188+ break
189+
183190 new_config_name = new_endpoint ["EndpointConfigName" ]
191+ new_config = sagemaker_session .sagemaker_client .describe_endpoint_config (
192+ EndpointConfigName = new_config_name
193+ )
184194
185195 assert old_config_name != new_config_name
186- assert new_production_variants ["InstanceType" ] == "ml.m4.xlarge"
187- assert new_production_variants ["InitialInstanceCount" ] == 1
188- assert new_production_variants ["AcceleratorType" ] is None
196+ assert new_config ["ProductionVariants" ][0 ]["InstanceType" ] == "ml.m4.xlarge"
197+ assert new_config ["ProductionVariants" ][0 ]["InitialInstanceCount" ] == 1
189198
190199
191- @pytest .mark .skip (
192- reason = "This test has always failed, but the failure was masked by a bug. "
193- "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
194- )
195200def test_deploy_model_with_update_non_existing_endpoint (
196201 mxnet_training_job , sagemaker_session , mxnet_full_version
197202):
@@ -216,7 +221,7 @@ def test_deploy_model_with_update_non_existing_endpoint(
216221 framework_version = mxnet_full_version ,
217222 )
218223 model .deploy (1 , "ml.t2.medium" , endpoint_name = endpoint_name )
219- sagemaker_session .describe_endpoint (EndpointName = endpoint_name )
224+ sagemaker_session .sagemaker_client . describe_endpoint (EndpointName = endpoint_name )
220225
221226 with pytest .raises (ValueError , message = expected_error_message ):
222227 model .deploy (
0 commit comments