@@ -91,12 +91,17 @@ def __init__(
9191 self .sagemaker_session = sagemaker_session or Session ()
9292 self .serializer = serializer
9393 self .deserializer = deserializer
94- self ._endpoint_config_name = self . _get_endpoint_config_name ()
95- self ._model_names = self . _get_model_names ()
94+ self ._endpoint_config_name = None
95+ self ._model_names = None
9696 self ._context = None
9797
9898 def predict (
99- self , data , initial_args = None , target_model = None , target_variant = None , inference_id = None
99+ self ,
100+ data ,
101+ initial_args = None ,
102+ target_model = None ,
103+ target_variant = None ,
104+ inference_id = None ,
100105 ):
101106 """Return the inference from the specified endpoint.
102107
@@ -138,7 +143,12 @@ def _handle_response(self, response):
138143 return self .deserializer .deserialize (response_body , content_type )
139144
140145 def _create_request_args (
141- self , data , initial_args = None , target_model = None , target_variant = None , inference_id = None
146+ self ,
147+ data ,
148+ initial_args = None ,
149+ target_model = None ,
150+ target_variant = None ,
151+ inference_id = None ,
142152 ):
143153 """Placeholder docstring"""
144154 args = dict (initial_args ) if initial_args else {}
@@ -223,24 +233,30 @@ def update_endpoint(
223233 associated with the endpoint.
224234 """
225235 production_variants = None
236+ current_model_names = self ._get_model_names ()
226237
227238 if initial_instance_count or instance_type or accelerator_type or model_name :
228239 if instance_type is None or initial_instance_count is None :
229240 raise ValueError (
230241 "Missing initial_instance_count and/or instance_type. Provided values: "
231242 "initial_instance_count={}, instance_type={}, accelerator_type={}, "
232243 "model_name={}." .format (
233- initial_instance_count , instance_type , accelerator_type , model_name
244+ initial_instance_count ,
245+ instance_type ,
246+ accelerator_type ,
247+ model_name ,
234248 )
235249 )
236250
237251 if model_name is None :
238- if len (self . _model_names ) > 1 :
252+ if len (current_model_names ) > 1 :
239253 raise ValueError (
240254 "Unable to choose a default model for a new EndpointConfig because "
241- "the endpoint has multiple models: {}" .format (", " .join (self ._model_names ))
255+ "the endpoint has multiple models: {}" .format (
256+ ", " .join (current_model_names )
257+ )
242258 )
243- model_name = self . _model_names [0 ]
259+ model_name = current_model_names [0 ]
244260 else :
245261 self ._model_names = [model_name ]
246262
@@ -252,9 +268,10 @@ def update_endpoint(
252268 )
253269 production_variants = [production_variant_config ]
254270
255- new_endpoint_config_name = name_from_base (self ._endpoint_config_name )
271+ current_endpoint_config_name = self ._get_endpoint_config_name ()
272+ new_endpoint_config_name = name_from_base (current_endpoint_config_name )
256273 self .sagemaker_session .create_endpoint_config_from_existing (
257- self . _endpoint_config_name ,
274+ current_endpoint_config_name ,
258275 new_endpoint_config_name ,
259276 new_tags = tags ,
260277 new_kms_key = kms_key ,
@@ -268,7 +285,8 @@ def update_endpoint(
268285
269286 def _delete_endpoint_config (self ):
270287 """Delete the Amazon SageMaker endpoint configuration"""
271- self .sagemaker_session .delete_endpoint_config (self ._endpoint_config_name )
288+ current_endpoint_config_name = self ._get_endpoint_config_name ()
289+ self .sagemaker_session .delete_endpoint_config (current_endpoint_config_name )
272290
273291 def delete_endpoint (self , delete_endpoint_config = True ):
274292 """Delete the Amazon SageMaker endpoint backing this predictor.
@@ -291,7 +309,8 @@ def delete_model(self):
291309 """Deletes the Amazon SageMaker models backing this predictor."""
292310 request_failed = False
293311 failed_models = []
294- for model_name in self ._model_names :
312+ current_model_names = self ._get_model_names ()
313+ for model_name in current_model_names :
295314 try :
296315 self .sagemaker_session .delete_model (model_name )
297316 except Exception : # pylint: disable=broad-except
@@ -460,26 +479,33 @@ def endpoint_context(self):
460479 if len (contexts ) != 0 :
461480 # create endpoint context object
462481 self ._context = EndpointContext .load (
463- sagemaker_session = self .sagemaker_session , context_name = contexts [0 ].context_name
482+ sagemaker_session = self .sagemaker_session ,
483+ context_name = contexts [0 ].context_name ,
464484 )
465485
466486 return self ._context
467487
468488 def _get_endpoint_config_name (self ):
469489 """Placeholder docstring"""
490+ if self ._endpoint_config_name is not None :
491+ return self ._endpoint_config_name
470492 endpoint_desc = self .sagemaker_session .sagemaker_client .describe_endpoint (
471493 EndpointName = self .endpoint_name
472494 )
473- endpoint_config_name = endpoint_desc ["EndpointConfigName" ]
474- return endpoint_config_name
495+ self . _endpoint_config_name = endpoint_desc ["EndpointConfigName" ]
496+ return self . _endpoint_config_name
475497
476498 def _get_model_names (self ):
477499 """Placeholder docstring"""
500+ if self ._model_names is not None :
501+ return self ._model_names
502+ current_endpoint_config_name = self ._get_endpoint_config_name ()
478503 endpoint_config = self .sagemaker_session .sagemaker_client .describe_endpoint_config (
479- EndpointConfigName = self . _endpoint_config_name
504+ EndpointConfigName = current_endpoint_config_name
480505 )
481506 production_variants = endpoint_config ["ProductionVariants" ]
482- return [d ["ModelName" ] for d in production_variants ]
507+ self ._model_names = [d ["ModelName" ] for d in production_variants ]
508+ return self ._model_names
483509
484510 @property
485511 def content_type (self ):
0 commit comments