@@ -16,26 +16,19 @@
 #

 import logging
-import time
+from collections.abc import Iterable
+from collections.abc import Mapping
+from collections.abc import Sequence
 from typing import Any
-from typing import Dict
-from typing import Iterable
-from typing import Mapping
 from typing import Optional
-from typing import Sequence

 from google.api_core.exceptions import ServerError
 from google.api_core.exceptions import TooManyRequests
 from google.cloud import aiplatform

-from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
-from apache_beam.metrics.metric import Metrics
 from apache_beam.ml.inference import utils
-from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
-from apache_beam.utils import retry
-
-MSEC_TO_SEC = 1000
+from apache_beam.ml.inference.base import RemoteModelHandler

 LOGGER = logging.getLogger("VertexAIModelHandlerJSON")

@@ -59,9 +52,9 @@ def _retry_on_appropriate_gcp_error(exception):
   return isinstance(exception, (TooManyRequests, ServerError))


-class VertexAIModelHandlerJSON(ModelHandler[Any,
-                                            PredictionResult,
-                                            aiplatform.Endpoint]):
+class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
+                                                  PredictionResult,
+                                                  aiplatform.Endpoint]):
   def __init__(
       self,
       endpoint_id: str,
@@ -139,14 +132,10 @@ def __init__(
     _ = self._retrieve_endpoint(
         self.endpoint_name, self.location, self.is_private)

-    # Configure AdaptiveThrottler and throttling metrics for client-side
-    # throttling behavior.
-    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
-    # for more details.
-    self.throttled_secs = Metrics.counter(
-        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
-    self.throttler = AdaptiveThrottler(
-        window_ms=1, bucket_ms=1, overload_ratio=2)
+    super().__init__(
+        namespace='VertexAIModelHandlerJSON',
+        retry_filter=_retry_on_appropriate_gcp_error,
+        **kwargs)

   def _retrieve_endpoint(
       self, endpoint_id: str, location: str,
@@ -183,7 +172,7 @@ def _retrieve_endpoint(

     return endpoint

-  def load_model(self) -> aiplatform.Endpoint:
+  def create_client(self) -> aiplatform.Endpoint:
     """Loads the Endpoint object used to build and send prediction request to
     Vertex AI.
     """
@@ -193,39 +182,11 @@ def load_model(self) -> aiplatform.Endpoint:
         self.endpoint_name, self.location, self.is_private)
     return ep

-  @retry.with_exponential_backoff(
-      num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
-  def get_request(
-      self,
-      batch: Sequence[Any],
-      model: aiplatform.Endpoint,
-      throttle_delay_secs: int,
-      inference_args: Optional[Dict[str, Any]]):
-    while self.throttler.throttle_request(time.time() * MSEC_TO_SEC):
-      LOGGER.info(
-          "Delaying request for %d seconds due to previous failures",
-          throttle_delay_secs)
-      time.sleep(throttle_delay_secs)
-      self.throttled_secs.inc(throttle_delay_secs)
-
-    try:
-      req_time = time.time()
-      prediction = model.predict(
-          instances=list(batch), parameters=inference_args)
-      self.throttler.successful_request(req_time * MSEC_TO_SEC)
-      return prediction
-    except TooManyRequests as e:
-      LOGGER.warning("request was limited by the service with code %i", e.code)
-      raise
-    except Exception as e:
-      LOGGER.error("unexpected exception raised as part of request, got %s", e)
-      raise
-
-  def run_inference(
+  def request(
       self,
       batch: Sequence[Any],
       model: aiplatform.Endpoint,
-      inference_args: Optional[Dict[str, Any]] = None
+      inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
     """ Sends a prediction request to a Vertex AI endpoint containing batch
     of inputs and matches that input with the prediction response from
@@ -242,16 +203,11 @@ def run_inference(
     Returns:
       An iterable of Predictions.
     """
-
-    # Endpoint.predict returns a Prediction type with the prediction values
-    # along with model metadata
-    prediction = self.get_request(
-        batch, model, throttle_delay_secs=5, inference_args=inference_args)
-
+    prediction = model.predict(instances=list(batch), parameters=inference_args)
     return utils._convert_to_result(
         batch, prediction.predictions, prediction.deployed_model_id)

-  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
     pass

   def batch_elements_kwargs(self) -> Mapping[str, Any]:
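
For context, here is a minimal usage sketch of the migrated handler with RunInference. It is illustrative only: the endpoint ID, project, and location below are placeholders, and the import path assumes the handler stays in apache_beam.ml.inference.vertex_ai_inference. Client-side throttling and retries on TooManyRequests/ServerError are now expected to come from the RemoteModelHandler base class (via the retry_filter passed in super().__init__) rather than from the handler itself.

# Hypothetical usage sketch; endpoint_id, project, and location are
# placeholders, not values taken from this change.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON

model_handler = VertexAIModelHandlerJSON(
    endpoint_id="1234567890",  # placeholder Vertex AI endpoint ID
    project="my-project",      # placeholder GCP project
    location="us-central1")

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([[1.0, 2.0], [3.0, 4.0]])  # JSON-serializable instances
      | RunInference(model_handler)
      | beam.Map(print))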
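The contract this diff adopts is: subclass RemoteModelHandler, implement create_client and request, and pass namespace and retry_filter to super().__init__. Below is a sketch of another handler following the same pattern; EchoClient and EchoModelHandler are invented for illustration, and the base-class keyword arguments mirror only what this diff itself exercises.

# Hypothetical subclass sketch following the same RemoteModelHandler
# contract as the Vertex AI handler above.
from collections.abc import Iterable, Sequence
from typing import Any, Optional

from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RemoteModelHandler


class EchoClient:
  """Stand-in for a remote service client (invented for this sketch)."""
  def predict(self, instances):
    # Echo the inputs back as "predictions".
    return list(instances)


class EchoModelHandler(RemoteModelHandler[Any, PredictionResult, EchoClient]):
  def __init__(self, **kwargs):
    super().__init__(
        namespace='EchoModelHandler',
        retry_filter=lambda exception: False,  # never retry in this sketch
        **kwargs)

  def create_client(self) -> EchoClient:
    # Counterpart of create_client in the diff: build the remote client.
    return EchoClient()

  def request(
      self,
      batch: Sequence[Any],
      model: EchoClient,
      inference_args: Optional[dict[str, Any]] = None
  ) -> Iterable[PredictionResult]:
    # Counterpart of request in the diff: one call per batch; the base
    # class owns throttling and retry around this method.
    predictions = model.predict(list(batch))
    return utils._convert_to_result(batch, predictions)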