Skip to content

Commit 5e09ce5

Browse files
committed
fix: fix an issue where target_container_hostname not being available in predict
1 parent aba802c commit 5e09ce5

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/sagemaker/base_predictor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def predict(
154154
inference_id=None,
155155
custom_attributes=None,
156156
component_name: Optional[str] = None,
157+
target_container_hostname=None,
157158
):
158159
"""Return the inference from the specified endpoint.
159160
@@ -188,6 +189,9 @@ def predict(
188189
function (Default: None).
189190
component_name (str): Optional. Name of the Amazon SageMaker inference component
190191
corresponding the predictor.
192+
target_container_hostname (str): Optional. If the endpoint hosts multiple containers
193+
and is configured to use direct invocation, this parameter specifies the host name
194+
of the container to invoke. (Default: None).
191195
192196
Returns:
193197
object: Inference for the given input. If a deserializer was specified when creating
@@ -203,6 +207,7 @@ def predict(
203207
target_variant=target_variant,
204208
inference_id=inference_id,
205209
custom_attributes=custom_attributes,
210+
target_container_hostname=target_container_hostname,
206211
)
207212

208213
inference_component_name = component_name or self._get_component_name()

tests/unit/test_predictor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,29 @@ def test_predict_call_with_inference_id():
135135
assert result == RETURN_VALUE
136136

137137

138+
def test_predict_call_with_target_container_hostname():
139+
sagemaker_session = empty_sagemaker_session()
140+
predictor = Predictor(ENDPOINT, sagemaker_session)
141+
142+
data = "untouched"
143+
result = predictor.predict(data, target_container_hostname="test_target_container_hostname")
144+
145+
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
146+
147+
expected_request_args = {
148+
"Accept": DEFAULT_ACCEPT,
149+
"Body": data,
150+
"ContentType": DEFAULT_CONTENT_TYPE,
151+
"EndpointName": ENDPOINT,
152+
"TargetContainerHostname": "test_target_container_hostname",
153+
}
154+
155+
_, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
156+
assert kwargs == expected_request_args
157+
158+
assert result == RETURN_VALUE
159+
160+
138161
def test_multi_model_predict_call():
139162
sagemaker_session = empty_sagemaker_session()
140163
predictor = Predictor(ENDPOINT, sagemaker_session)

0 commit comments

Comments
 (0)