
Commit cbf45e7

feat: Add support for custom prediction routes in Vertex AI inference using the invoke_route parameter and custom response parsing.
1 parent: 423a3c3 · commit: cbf45e7

File tree: 3 files changed, +141 −3 lines

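For context, here is a minimal usage sketch of the new parameter from a Beam pipeline (not part of this commit): the endpoint ID, project, and route are placeholder values, and the route must match whatever custom prediction route the serving container behind the endpoint exposes.

# Usage sketch (illustrative only; placeholder IDs and route).
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON

model_handler = VertexAIModelHandlerJSON(
    endpoint_id="1234567890",      # placeholder endpoint ID
    project="my-project",          # placeholder GCP project
    location="us-central1",
    invoke_route="/predict/v1")    # must match the container's custom route

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{"input": "example"}])
      | RunInference(model_handler)
      | beam.Map(print))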

sdks/python/apache_beam/ml/inference/vertex_ai_inference.py

Lines changed: 68 additions & 3 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import json
 import logging
 from collections.abc import Iterable
 from collections.abc import Mapping
@@ -63,6 +64,7 @@ def __init__(
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
+      invoke_route: Optional[str] = None,
       *,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
@@ -95,6 +97,12 @@ def __init__(
       private: optional. if the deployed Vertex AI endpoint is
         private, set to true. Requires a network to be provided
         as well.
+      invoke_route: optional. the custom route path to use when invoking
+        endpoints with arbitrary prediction routes. When specified, uses
+        `Endpoint.invoke()` instead of `Endpoint.predict()`. The route
+        should start with a forward slash, e.g., "/predict/v1".
+        See https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
+        for more information.
       min_batch_size: optional. the minimum batch size to use when batching
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
@@ -104,6 +112,7 @@ def __init__(
     """
     self._batching_kwargs = {}
     self._env_vars = kwargs.get('env_vars', {})
+    self._invoke_route = invoke_route
     if min_batch_size is not None:
       self._batching_kwargs["min_batch_size"] = min_batch_size
     if max_batch_size is not None:
@@ -203,9 +212,65 @@ def request(
     Returns:
       An iterable of Predictions.
     """
-    prediction = model.predict(instances=list(batch), parameters=inference_args)
-    return utils._convert_to_result(
-        batch, prediction.predictions, prediction.deployed_model_id)
+    if self._invoke_route:
+      # Use invoke() for endpoints with custom prediction routes
+      request_body = {"instances": list(batch)}
+      if inference_args:
+        request_body["parameters"] = inference_args
+      response = model.invoke(
+          request_path=self._invoke_route,
+          body=json.dumps(request_body).encode("utf-8"),
+          headers={"Content-Type": "application/json"})
+      return self._parse_invoke_response(batch, response)
+    else:
+      prediction = model.predict(
+          instances=list(batch), parameters=inference_args)
+      return utils._convert_to_result(
+          batch, prediction.predictions, prediction.deployed_model_id)
+
+  def _parse_invoke_response(
+      self, batch: Sequence[Any],
+      response: bytes) -> Iterable[PredictionResult]:
+    """Parses the response from Endpoint.invoke() into PredictionResults.
+
+    Args:
+      batch: the original batch of inputs.
+      response: the raw bytes response from invoke().
+
+    Returns:
+      An iterable of PredictionResults.
+    """
+    try:
+      response_json = json.loads(response.decode("utf-8"))
+    except (json.JSONDecodeError, UnicodeDecodeError) as e:
+      LOGGER.warning(
+          "Failed to decode invoke response as JSON, returning raw bytes: %s",
+          e)
+      # Return raw response for each batch item
+      return [
+          PredictionResult(example=example, inference=response)
+          for example in batch
+      ]
+
+    # Handle standard Vertex AI response format with "predictions" key
+    if isinstance(response_json, dict) and "predictions" in response_json:
+      predictions = response_json["predictions"]
+      model_id = response_json.get("deployedModelId")
+      return utils._convert_to_result(batch, predictions, model_id)
+
+    # Handle response as a list of predictions (one per input)
+    if isinstance(response_json, list) and len(response_json) == len(batch):
+      return utils._convert_to_result(batch, response_json, None)
+
+    # Handle single prediction response
+    if len(batch) == 1:
+      return [PredictionResult(example=batch[0], inference=response_json)]
+
+    # Fallback: return the full response for each batch item
+    return [
+        PredictionResult(example=example, inference=response_json)
+        for example in batch
+    ]
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
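To make the new code path concrete, here is an illustrative sketch (not from the commit) of the JSON body that request() builds for Endpoint.invoke() and the response shapes that _parse_invoke_response maps back to PredictionResults; the sample values and the "temperature" parameter are invented.

# Illustrative only: mirrors the request-body construction in the diff above.
import json

batch = [{"input": "a"}, {"input": "b"}]
inference_args = {"temperature": 0.2}  # hypothetical model parameters

request_body = {"instances": list(batch)}
if inference_args:
  request_body["parameters"] = inference_args
payload = json.dumps(request_body).encode("utf-8")  # sent as the invoke() body
print(payload)

# Response shapes the parser handles, in order of precedence:
#   b'{"predictions": ["p1", "p2"], "deployedModelId": "123"}'  -> one result per input
#   b'["p1", "p2"]' (length matches the batch)                   -> one result per input
#   any other JSON when the batch has a single element           -> wrapped as one result
#   bytes that are not valid JSON                                 -> raw bytes per input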

sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py

Lines changed: 65 additions & 0 deletions
@@ -48,5 +48,70 @@ def test_exception_on_private_without_network(self):
         private=True)
 
 
+class ParseInvokeResponseTest(unittest.TestCase):
+  """Tests for _parse_invoke_response method."""
+
+  def _create_handler_with_invoke_route(self, invoke_route="/test"):
+    """Creates a mock handler with invoke_route for testing."""
+    import unittest.mock as mock
+    with mock.patch.object(
+        VertexAIModelHandlerJSON, '_retrieve_endpoint', return_value=None):
+      handler = VertexAIModelHandlerJSON(
+          endpoint_id="1",
+          project="testproject",
+          location="us-central1",
+          invoke_route=invoke_route)
+    return handler
+
+  def test_parse_invoke_response_with_predictions_key(self):
+    """Test parsing response with standard 'predictions' key."""
+    handler = self._create_handler_with_invoke_route()
+    batch = [{"input": "test1"}, {"input": "test2"}]
+    response = b'{"predictions": ["result1", "result2"], "deployedModelId": "model123"}'
+
+    results = list(handler._parse_invoke_response(batch, response))
+
+    self.assertEqual(len(results), 2)
+    self.assertEqual(results[0].example, {"input": "test1"})
+    self.assertEqual(results[0].inference, "result1")
+    self.assertEqual(results[1].example, {"input": "test2"})
+    self.assertEqual(results[1].inference, "result2")
+
+  def test_parse_invoke_response_list_format(self):
+    """Test parsing response as a list of predictions."""
+    handler = self._create_handler_with_invoke_route()
+    batch = [{"input": "test1"}, {"input": "test2"}]
+    response = b'["result1", "result2"]'
+
+    results = list(handler._parse_invoke_response(batch, response))
+
+    self.assertEqual(len(results), 2)
+    self.assertEqual(results[0].inference, "result1")
+    self.assertEqual(results[1].inference, "result2")
+
+  def test_parse_invoke_response_single_prediction(self):
+    """Test parsing response with a single prediction."""
+    handler = self._create_handler_with_invoke_route()
+    batch = [{"input": "test1"}]
+    response = b'{"output": "single result"}'
+
+    results = list(handler._parse_invoke_response(batch, response))
+
+    self.assertEqual(len(results), 1)
+    self.assertEqual(results[0].inference, {"output": "single result"})
+
+  def test_parse_invoke_response_non_json(self):
+    """Test handling non-JSON response."""
+    handler = self._create_handler_with_invoke_route()
+    batch = [{"input": "test1"}]
+    response = b'not valid json'
+
+    results = list(handler._parse_invoke_response(batch, response))
+
+    self.assertEqual(len(results), 1)
+    self.assertEqual(results[0].inference, response)
+
+
 if __name__ == '__main__':
   unittest.main()
+
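The new tests are plain unittest cases, so one way to run only this class locally is sketched below (assuming a working Beam Python development environment; the module path is taken from the file path above).

# Hypothetical local run of only the new test class.
import unittest
from apache_beam.ml.inference.vertex_ai_inference_test import ParseInvokeResponseTest

suite = unittest.defaultTestLoader.loadTestsFromTestCase(ParseInvokeResponseTest)
unittest.TextTestRunner(verbosity=2).run(suite)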

sdks/python/apache_beam/yaml/yaml_ml.py

Lines changed: 8 additions & 0 deletions
@@ -168,6 +168,7 @@ def __init__(
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
+      invoke_route: Optional[str] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None):
@@ -236,6 +237,12 @@ def __init__(
       private: If the deployed Vertex AI endpoint is
         private, set to true. Requires a network to be provided
         as well.
+      invoke_route: The custom route path to use when invoking
+        endpoints with arbitrary prediction routes. When specified, uses
+        `Endpoint.invoke()` instead of `Endpoint.predict()`. The route
+        should start with a forward slash, e.g., "/predict/v1".
+        See https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
+        for more information.
       min_batch_size: The minimum batch size to use when batching
        inputs.
       max_batch_size: The maximum batch size to use when batching
@@ -258,6 +265,7 @@ def __init__(
         experiment=experiment,
         network=network,
         private=private,
+        invoke_route=invoke_route,
         min_batch_size=min_batch_size,
         max_batch_size=max_batch_size,
         max_batch_duration_secs=max_batch_duration_secs)
