File tree Expand file tree Collapse file tree 2 files changed +28
-0
lines changed Expand file tree Collapse file tree 2 files changed +28
-0
lines changed Original file line number Diff line number Diff line change @@ -154,6 +154,7 @@ def predict(
154
154
inference_id = None ,
155
155
custom_attributes = None ,
156
156
component_name : Optional [str ] = None ,
157
+ target_container_hostname = None ,
157
158
):
158
159
"""Return the inference from the specified endpoint.
159
160
@@ -188,6 +189,9 @@ def predict(
188
189
function (Default: None).
189
190
component_name (str): Optional. Name of the Amazon SageMaker inference component
190
191
corresponding the predictor.
192
+ target_container_hostname (str): Optional. If the endpoint hosts multiple containers
193
+ and is configured to use direct invocation, this parameter specifies the host name
194
+ of the container to invoke. (Default: None).
191
195
192
196
Returns:
193
197
object: Inference for the given input. If a deserializer was specified when creating
@@ -203,6 +207,7 @@ def predict(
203
207
target_variant = target_variant ,
204
208
inference_id = inference_id ,
205
209
custom_attributes = custom_attributes ,
210
+ target_container_hostname = target_container_hostname ,
206
211
)
207
212
208
213
inference_component_name = component_name or self ._get_component_name ()
Original file line number Diff line number Diff line change @@ -135,6 +135,29 @@ def test_predict_call_with_inference_id():
135
135
assert result == RETURN_VALUE
136
136
137
137
138
+ def test_predict_call_with_target_container_hostname ():
139
+ sagemaker_session = empty_sagemaker_session ()
140
+ predictor = Predictor (ENDPOINT , sagemaker_session )
141
+
142
+ data = "untouched"
143
+ result = predictor .predict (data , target_container_hostname = "test_target_container_hostname" )
144
+
145
+ assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint .called
146
+
147
+ expected_request_args = {
148
+ "Accept" : DEFAULT_ACCEPT ,
149
+ "Body" : data ,
150
+ "ContentType" : DEFAULT_CONTENT_TYPE ,
151
+ "EndpointName" : ENDPOINT ,
152
+ "TargetContainerHostname" : "test_target_container_hostname" ,
153
+ }
154
+
155
+ _ , kwargs = sagemaker_session .sagemaker_runtime_client .invoke_endpoint .call_args
156
+ assert kwargs == expected_request_args
157
+
158
+ assert result == RETURN_VALUE
159
+
160
+
138
161
def test_multi_model_predict_call ():
139
162
sagemaker_session = empty_sagemaker_session ()
140
163
predictor = Predictor (ENDPOINT , sagemaker_session )
You can’t perform that action at this time.
0 commit comments