@@ -151,7 +151,11 @@ def test_deploy_model_with_tags_and_kms(
151151
152152
153153def test_deploy_model_with_update_endpoint (
154- mxnet_training_job , sagemaker_session , mxnet_full_version , cpu_instance_type
154+ mxnet_training_job ,
155+ sagemaker_session ,
156+ mxnet_full_version ,
157+ cpu_instance_type ,
158+ alternative_cpu_instance_type ,
155159):
156160 endpoint_name = "test-mxnet-deploy-model-{}" .format (sagemaker_timestamp ())
157161
@@ -169,13 +173,13 @@ def test_deploy_model_with_update_endpoint(
169173 sagemaker_session = sagemaker_session ,
170174 framework_version = mxnet_full_version ,
171175 )
172- model .deploy (1 , "ml.t2.medium" , endpoint_name = endpoint_name )
176+ model .deploy (1 , alternative_cpu_instance_type , endpoint_name = endpoint_name )
173177 old_endpoint = sagemaker_session .sagemaker_client .describe_endpoint (
174178 EndpointName = endpoint_name
175179 )
176180 old_config_name = old_endpoint ["EndpointConfigName" ]
177181
178- model .deploy (1 , "ml.m4.xlarge" , update_endpoint = True , endpoint_name = endpoint_name )
182+ model .deploy (1 , cpu_instance_type , update_endpoint = True , endpoint_name = endpoint_name )
179183
180184 # Wait for endpoint to finish updating
181185 max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
@@ -197,12 +201,16 @@ def test_deploy_model_with_update_endpoint(
197201 )
198202
199203 assert old_config_name != new_config_name
200- assert new_config ["ProductionVariants" ][0 ]["InstanceType" ] == "ml.m4.xlarge"
204+ assert new_config ["ProductionVariants" ][0 ]["InstanceType" ] == cpu_instance_type
201205 assert new_config ["ProductionVariants" ][0 ]["InitialInstanceCount" ] == 1
202206
203207
204208def test_deploy_model_with_update_non_existing_endpoint (
205- mxnet_training_job , sagemaker_session , mxnet_full_version , cpu_instance_type
209+ mxnet_training_job ,
210+ sagemaker_session ,
211+ mxnet_full_version ,
212+ cpu_instance_type ,
213+ alternative_cpu_instance_type ,
206214):
207215 endpoint_name = "test-mxnet-deploy-model-{}" .format (sagemaker_timestamp ())
208216 expected_error_message = (
@@ -224,7 +232,7 @@ def test_deploy_model_with_update_non_existing_endpoint(
224232 sagemaker_session = sagemaker_session ,
225233 framework_version = mxnet_full_version ,
226234 )
227- model .deploy (1 , "ml.t2.medium" , endpoint_name = endpoint_name )
235+ model .deploy (1 , alternative_cpu_instance_type , endpoint_name = endpoint_name )
228236 sagemaker_session .sagemaker_client .describe_endpoint (EndpointName = endpoint_name )
229237
230238 with pytest .raises (ValueError , message = expected_error_message ):
0 commit comments