
Commit 9e6bb25

Migrate Vertex AI Model Handler to new base class
1 parent: bb29415

File tree: 2 files changed (+18, -61 lines)

sdks/python/apache_beam/io/requestresponse.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -43,7 +43,6 @@
 from apache_beam.coders import coders
 from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
 from apache_beam.metrics import Metrics
-from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC
 from apache_beam.transforms.util import BatchElements
 from apache_beam.utils import retry
@@ -58,6 +57,8 @@
 # for cache record.
 DEFAULT_CACHE_ENTRY_TTL_SEC = 24 * 60 * 60

+MSEC_TO_SEC = 1000
+
 _LOGGER = logging.getLogger(__name__)

 __all__ = [
```
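
With the constant defined locally, requestresponse.py no longer imports anything from the inference module. As a rough illustration (not this file's actual code), MSEC_TO_SEC scales wall-clock seconds into the millisecond timestamps that AdaptiveThrottler works with:

```python
# Illustrative sketch only: AdaptiveThrottler takes timestamps in
# milliseconds, so seconds from time.time() are scaled by MSEC_TO_SEC.
import time

from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler

MSEC_TO_SEC = 1000

throttler = AdaptiveThrottler(window_ms=1, bucket_ms=1, overload_ratio=2)

now_msec = time.time() * MSEC_TO_SEC
if not throttler.throttle_request(now_msec):
  # ... make the remote call, then record it as successful ...
  throttler.successful_request(now_msec)
```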

sdks/python/apache_beam/ml/inference/vertex_ai_inference.py

Lines changed: 16 additions & 60 deletions

```diff
@@ -16,26 +16,19 @@
 #

 import logging
-import time
+from collections.abc import Iterable
+from collections.abc import Mapping
+from collections.abc import Sequence
 from typing import Any
-from typing import Dict
-from typing import Iterable
-from typing import Mapping
 from typing import Optional
-from typing import Sequence

 from google.api_core.exceptions import ServerError
 from google.api_core.exceptions import TooManyRequests
 from google.cloud import aiplatform

-from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
-from apache_beam.metrics.metric import Metrics
 from apache_beam.ml.inference import utils
-from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
-from apache_beam.utils import retry
-
-MSEC_TO_SEC = 1000
+from apache_beam.ml.inference.base import RemoteModelHandler

 LOGGER = logging.getLogger("VertexAIModelHandlerJSON")
```
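
The import churn at the top is the standard PEP 585 cleanup: deprecated typing aliases are swapped for their collections.abc equivalents, and Dict gives way to the builtin dict in annotations (both valid on Python 3.9+). In miniature:

```python
# PEP 585 in miniature: the two import styles produce equivalent
# annotations, but the typing aliases are deprecated since Python 3.9.
from collections.abc import Iterable, Sequence  # new style
from typing import Any, Optional  # still needed for these

def request(batch: Sequence[Any],
            inference_args: Optional[dict[str, Any]] = None) -> Iterable[Any]:
  return list(batch)
```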

```diff
@@ -59,9 +52,9 @@ def _retry_on_appropriate_gcp_error(exception):
   return isinstance(exception, (TooManyRequests, ServerError))


-class VertexAIModelHandlerJSON(ModelHandler[Any,
-                                            PredictionResult,
-                                            aiplatform.Endpoint]):
+class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
+                                                  PredictionResult,
+                                                  aiplatform.Endpoint]):
   def __init__(
       self,
       endpoint_id: str,
```
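
Only the base class changes here; the three type parameters (input, result, client) carry over unchanged. The sketch below is a hypothetical reconstruction of the RemoteModelHandler surface implied by this diff, not the real code in apache_beam/ml/inference/base.py:

```python
# Hypothetical contract sketch inferred from this diff; the actual
# RemoteModelHandler lives in apache_beam/ml/inference/base.py.
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from typing import Any, Generic, Optional, TypeVar

ExampleT = TypeVar('ExampleT')
PredictionT = TypeVar('PredictionT')
ClientT = TypeVar('ClientT')


class RemoteModelHandler(ABC, Generic[ExampleT, PredictionT, ClientT]):
  @abstractmethod
  def create_client(self) -> ClientT:
    """Connect to the remote service; replaces load_model() here."""

  @abstractmethod
  def request(self, batch: Sequence[ExampleT], model: ClientT,
              inference_args: Optional[dict[str, Any]] = None
              ) -> Iterable[PredictionT]:
    """The raw remote call; retry and throttling live in the base class."""
```
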
```diff
@@ -139,14 +132,10 @@ def __init__(
     _ = self._retrieve_endpoint(
         self.endpoint_name, self.location, self.is_private)

-    # Configure AdaptiveThrottler and throttling metrics for client-side
-    # throttling behavior.
-    # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
-    # for more details.
-    self.throttled_secs = Metrics.counter(
-        VertexAIModelHandlerJSON, "cumulativeThrottlingSeconds")
-    self.throttler = AdaptiveThrottler(
-        window_ms=1, bucket_ms=1, overload_ratio=2)
+    super().__init__(
+        namespace='VertexAIModelHandlerJSON',
+        retry_filter=_retry_on_appropriate_gcp_error,
+        **kwargs)

   def _retrieve_endpoint(
       self, endpoint_id: str, location: str,
```
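
The hand-rolled throttler and metric setup moves behind super().__init__, and **kwargs now flows through to the base class. A hypothetical sketch of what the base presumably does with these arguments, modeled on the lines deleted above:

```python
# Hypothetical: modeled on the code deleted above, not the real base class.
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
from apache_beam.metrics.metric import Metrics


class _RemoteModelHandlerSketch:
  def __init__(self, namespace='', retry_filter=None, **kwargs):
    # Throttling metric and client-side throttler, now owned centrally
    # and reported under the subclass-supplied namespace.
    self.throttled_secs = Metrics.counter(
        namespace, 'cumulativeThrottlingSeconds')
    self.throttler = AdaptiveThrottler(
        window_ms=1, bucket_ms=1, overload_ratio=2)
    self.retry_filter = retry_filter
```
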
```diff
@@ -183,7 +172,7 @@ def _retrieve_endpoint(

     return endpoint

-  def load_model(self) -> aiplatform.Endpoint:
+  def create_client(self) -> aiplatform.Endpoint:
    """Loads the Endpoint object used to build and send prediction request to
    Vertex AI.
    """
```

```diff
@@ -193,39 +182,11 @@ def load_model(self) -> aiplatform.Endpoint:
         self.endpoint_name, self.location, self.is_private)
     return ep

-  @retry.with_exponential_backoff(
-      num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
-  def get_request(
-      self,
-      batch: Sequence[Any],
-      model: aiplatform.Endpoint,
-      throttle_delay_secs: int,
-      inference_args: Optional[Dict[str, Any]]):
-    while self.throttler.throttle_request(time.time() * MSEC_TO_SEC):
-      LOGGER.info(
-          "Delaying request for %d seconds due to previous failures",
-          throttle_delay_secs)
-      time.sleep(throttle_delay_secs)
-      self.throttled_secs.inc(throttle_delay_secs)
-
-    try:
-      req_time = time.time()
-      prediction = model.predict(
-          instances=list(batch), parameters=inference_args)
-      self.throttler.successful_request(req_time * MSEC_TO_SEC)
-      return prediction
-    except TooManyRequests as e:
-      LOGGER.warning("request was limited by the service with code %i", e.code)
-      raise
-    except Exception as e:
-      LOGGER.error("unexpected exception raised as part of request, got %s", e)
-      raise
-
-  def run_inference(
+  def request(
       self,
       batch: Sequence[Any],
       model: aiplatform.Endpoint,
-      inference_args: Optional[Dict[str, Any]] = None
+      inference_args: Optional[dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
     """ Sends a prediction request to a Vertex AI endpoint containing batch
     of inputs and matches that input with the prediction response from
```
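
Everything deleted here (the exponential-backoff decorator, the throttling loop, and the error logging) was per-handler boilerplate; request() keeps only the raw predict call. The base class presumably applies equivalent backoff via the retry_filter passed to __init__; the deleted decorator shows the shape of that contract, reproduced as a standalone sketch:

```python
# How the old code wired the retry filter; the base class presumably
# applies an equivalent wrapper around request() now.
from google.api_core.exceptions import ServerError, TooManyRequests

from apache_beam.utils import retry


def _retry_on_appropriate_gcp_error(exception):
  # Retry only on transient service-side failures.
  return isinstance(exception, (TooManyRequests, ServerError))


@retry.with_exponential_backoff(
    num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
def _call_endpoint(endpoint, batch, inference_args=None):
  # Retried with backoff only when the filter returns True.
  return endpoint.predict(instances=list(batch), parameters=inference_args)
```
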
```diff
@@ -242,16 +203,11 @@ def run_inference(
     Returns:
       An iterable of Predictions.
     """
-
-    # Endpoint.predict returns a Prediction type with the prediction values
-    # along with model metadata
-    prediction = self.get_request(
-        batch, model, throttle_delay_secs=5, inference_args=inference_args)
-
+    prediction = model.predict(instances=list(batch), parameters=inference_args)
     return utils._convert_to_result(
         batch, prediction.predictions, prediction.deployed_model_id)

-  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
     pass

   def batch_elements_kwargs(self) -> Mapping[str, Any]:
```