 # limitations under the License.
 #
 
+import json
 import logging
 from collections.abc import Iterable
 from collections.abc import Mapping
@@ -63,6 +64,7 @@ def __init__(
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
+      invoke_route: Optional[str] = None,
       *,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
@@ -95,6 +97,12 @@ def __init__(
       private: optional. if the deployed Vertex AI endpoint is
         private, set to true. Requires a network to be provided
         as well.
+      invoke_route: optional. the custom route path to use when invoking
+        endpoints with arbitrary prediction routes. When specified,
+        `Endpoint.invoke()` is used instead of `Endpoint.predict()`. The
+        route should start with a forward slash, e.g. "/predict/v1". See
+        https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
+        for more information.
       min_batch_size: optional. the minimum batch size to use when batching
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
@@ -104,6 +112,7 @@ def __init__(
104112 """
105113 self ._batching_kwargs = {}
106114 self ._env_vars = kwargs .get ('env_vars' , {})
115+ self ._invoke_route = invoke_route
107116 if min_batch_size is not None :
108117 self ._batching_kwargs ["min_batch_size" ] = min_batch_size
109118 if max_batch_size is not None :
@@ -203,9 +212,65 @@ def request(
     Returns:
       An iterable of Predictions.
     """
-    prediction = model.predict(instances=list(batch), parameters=inference_args)
-    return utils._convert_to_result(
-        batch, prediction.predictions, prediction.deployed_model_id)
+    if self._invoke_route:
+      # Use invoke() for endpoints with custom prediction routes.
+      request_body = {"instances": list(batch)}
+      if inference_args:
+        request_body["parameters"] = inference_args
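+      # The serialized body matches the standard Vertex AI prediction
+      # request shape, e.g. {"instances": [...], "parameters": {...}}.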
+      response = model.invoke(
+          request_path=self._invoke_route,
+          body=json.dumps(request_body).encode("utf-8"),
+          headers={"Content-Type": "application/json"})
+      return self._parse_invoke_response(batch, response)
+    else:
+      prediction = model.predict(
+          instances=list(batch), parameters=inference_args)
+      return utils._convert_to_result(
+          batch, prediction.predictions, prediction.deployed_model_id)
+
+  def _parse_invoke_response(
+      self, batch: Sequence[Any],
+      response: bytes) -> Iterable[PredictionResult]:
+    """Parses the response from Endpoint.invoke() into PredictionResults.
+
+    Args:
+      batch: the original batch of inputs.
+      response: the raw bytes response from invoke().
+
+    Returns:
+      An iterable of PredictionResults.
+    """
+    try:
+      response_json = json.loads(response.decode("utf-8"))
+    except (json.JSONDecodeError, UnicodeDecodeError) as e:
+      LOGGER.warning(
+          "Failed to decode invoke response as JSON, returning raw bytes: %s",
+          e)
+      # Return the raw response for each batch item.
+      return [
+          PredictionResult(example=example, inference=response)
+          for example in batch
+      ]
+
+    # Handle standard Vertex AI response format with "predictions" key,
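+    # e.g. {"predictions": [...], "deployedModelId": "123456"}.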
+    if isinstance(response_json, dict) and "predictions" in response_json:
+      predictions = response_json["predictions"]
+      model_id = response_json.get("deployedModelId")
+      return utils._convert_to_result(batch, predictions, model_id)
+
+    # Handle response as a list of predictions (one per input).
+    if isinstance(response_json, list) and len(response_json) == len(batch):
+      return utils._convert_to_result(batch, response_json, None)
+
+    # Handle a single prediction response.
+    if len(batch) == 1:
+      return [PredictionResult(example=batch[0], inference=response_json)]
+
+    # Fallback: return the full response for each batch item.
+    return [
+        PredictionResult(example=example, inference=response_json)
+        for example in batch
+    ]
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
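
For context, a minimal usage sketch of the new parameter, assuming the enclosing class is Beam's VertexAIModelHandlerJSON from apache_beam.ml.inference.vertex_ai_inference (the endpoint, project, and route values below are placeholders):

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON

# Placeholder endpoint/project values; invoke_route targets a custom
# prediction route served by the deployed model container.
handler = VertexAIModelHandlerJSON(
    endpoint_id="1234567890",
    project="my-project",
    location="us-central1",
    invoke_route="/predict/v1")

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([{"prompt": "hello"}])
      | RunInference(handler))

With invoke_route set, each batch is serialized to a {"instances": [...]} JSON body, sent to the custom route via Endpoint.invoke(), and the raw response is parsed by _parse_invoke_response as shown in the diff above.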