@@ -83,7 +83,7 @@ def __init__(
8383 self ._endpoint_config_name = self ._get_endpoint_config_name ()
8484 self ._model_names = self ._get_model_names ()
8585
86- def predict (self , data , initial_args = None , target_model = None , target_variant = None ):
86+ def predict (self , data , initial_args = None , target_model = None ):
8787 """Return the inference from the specified endpoint.
8888
8989 Args:
@@ -98,9 +98,6 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
9898 target_model (str): S3 model artifact path to run an inference request on,
9999 in case of a multi model endpoint. Does not apply to endpoints hosting
100100 single model (Default: None)
101- target_variant (str): The name of the production variant to run an inference
102- request on (Default: None). Note that the ProductionVariant identifies the model
103- you want to host and the resources you want to deploy for hosting it.
104101
105102 Returns:
106103 object: Inference for the given input. If a deserializer was specified when creating
@@ -109,7 +106,7 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
109106 as is.
110107 """
111108
112- request_args = self ._create_request_args (data , initial_args , target_model , target_variant )
109+ request_args = self ._create_request_args (data , initial_args , target_model )
113110 response = self .sagemaker_session .sagemaker_runtime_client .invoke_endpoint (** request_args )
114111 return self ._handle_response (response )
115112
@@ -126,13 +123,12 @@ def _handle_response(self, response):
126123 response_body .close ()
127124 return data
128125
129- def _create_request_args (self , data , initial_args = None , target_model = None , target_variant = None ):
126+ def _create_request_args (self , data , initial_args = None , target_model = None ):
130127 """
131128 Args:
132129 data:
133130 initial_args:
134131 target_model:
135- target_variant:
136132 """
137133 args = dict (initial_args ) if initial_args else {}
138134
@@ -148,9 +144,6 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
148144 if target_model :
149145 args ["TargetModel" ] = target_model
150146
151- if target_variant :
152- args ["TargetVariant" ] = target_variant
153-
154147 if self .serializer is not None :
155148 data = self .serializer (data )
156149
0 commit comments