@@ -177,6 +177,67 @@ def test_multi_data_model_deploy_pretrained_models(
177177 assert "Could not find endpoint" in str (exception .value )
178178
179179
180+ @pytest .mark .local_mode
181+ def test_multi_data_model_deploy_pretrained_models_local_mode (container_image , sagemaker_session ):
182+ timestamp = sagemaker_timestamp ()
183+ endpoint_name = "test-multimodel-endpoint-{}" .format (timestamp )
184+ model_name = "test-multimodel-{}" .format (timestamp )
185+
186+ # Define pretrained model local path
187+ pretrained_model_data_local_path = os .path .join (DATA_DIR , "sparkml_model" , "mleap_model.tar.gz" )
188+
189+ with timeout (minutes = 30 ):
190+ model_data_prefix = os .path .join (
191+ "s3://" , sagemaker_session .default_bucket (), "multimodel-{}/" .format (timestamp )
192+ )
193+ multi_data_model = MultiDataModel (
194+ name = model_name ,
195+ model_data_prefix = model_data_prefix ,
196+ image = container_image ,
197+ role = ROLE ,
198+ sagemaker_session = sagemaker_session ,
199+ )
200+
201+ # Add model before deploy
202+ multi_data_model .add_model (pretrained_model_data_local_path , PRETRAINED_MODEL_PATH_1 )
203+ # Deploy model to an endpoint
204+ multi_data_model .deploy (1 , "local" , endpoint_name = endpoint_name )
205+ # Add models after deploy
206+ multi_data_model .add_model (pretrained_model_data_local_path , PRETRAINED_MODEL_PATH_2 )
207+
208+ endpoint_models = []
209+ for model_path in multi_data_model .list_models ():
210+ endpoint_models .append (model_path )
211+ assert PRETRAINED_MODEL_PATH_1 in endpoint_models
212+ assert PRETRAINED_MODEL_PATH_2 in endpoint_models
213+
214+ predictor = RealTimePredictor (
215+ endpoint = endpoint_name ,
216+ sagemaker_session = multi_data_model .sagemaker_session ,
217+ serializer = npy_serializer ,
218+ deserializer = string_deserializer ,
219+ )
220+
221+ data = numpy .zeros (shape = (1 , 1 , 28 , 28 ))
222+ result = predictor .predict (data , target_model = PRETRAINED_MODEL_PATH_1 )
223+ assert result == "Invoked model: {}" .format (PRETRAINED_MODEL_PATH_1 )
224+
225+ result = predictor .predict (data , target_model = PRETRAINED_MODEL_PATH_2 )
226+ assert result == "Invoked model: {}" .format (PRETRAINED_MODEL_PATH_2 )
227+
228+ # Cleanup
229+ multi_data_model .sagemaker_session .sagemaker_client .delete_endpoint_config (
230+ EndpointConfigName = endpoint_name
231+ )
232+ multi_data_model .sagemaker_session .delete_endpoint (endpoint_name )
233+ multi_data_model .delete_model ()
234+ with pytest .raises (Exception ) as exception :
235+ sagemaker_session .sagemaker_client .describe_model (ModelName = multi_data_model .name )
236+ assert "Could not find model" in str (exception .value )
237+ sagemaker_session .sagemaker_client .describe_endpoint_config (name = endpoint_name )
238+ assert "Could not find endpoint" in str (exception .value )
239+
240+
180241def test_multi_data_model_deploy_trained_model_from_framework_estimator (
181242 container_image , sagemaker_session , cpu_instance_type
182243):
0 commit comments