Skip to content

Commit 3302cfa

Browse files
authored
Update Vertex AI embedding handlers to use RemoteModelHandler (#35670)
* Update Vertex AI embedding handlers to use RemoteModelHandler
* Remove unused imports
* Fix incorrect type
1 parent 031cc2d commit 3302cfa

File tree

1 file changed

+32
-97
lines changed
  • sdks/python/apache_beam/ml/transforms/embeddings

1 file changed

+32
-97
lines changed

sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py

Lines changed: 32 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
# to install Vertex AI Python SDK.
2121

2222
import logging
23-
import time
24-
from collections.abc import Iterable
2523
from collections.abc import Sequence
2624
from typing import Any
2725
from typing import Optional
@@ -32,14 +30,12 @@
3230

3331
import apache_beam as beam
3432
import vertexai
35-
from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
36-
from apache_beam.metrics.metric import Metrics
3733
from apache_beam.ml.inference.base import ModelHandler
34+
from apache_beam.ml.inference.base import RemoteModelHandler
3835
from apache_beam.ml.inference.base import RunInference
3936
from apache_beam.ml.transforms.base import EmbeddingsManager
4037
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
4138
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
42-
from apache_beam.utils import retry
4339
from vertexai.language_models import TextEmbeddingInput
4440
from vertexai.language_models import TextEmbeddingModel
4541
from vertexai.vision_models import Image
@@ -80,7 +76,7 @@ def _retry_on_appropriate_gcp_error(exception):
8076
return isinstance(exception, (TooManyRequests, ServerError))
8177

8278

83-
class _VertexAITextEmbeddingHandler(ModelHandler):
79+
class _VertexAITextEmbeddingHandler(RemoteModelHandler):
8480
"""
8581
Note: Intended for internal use and guarantees no backwards compatibility.
8682
"""
@@ -92,7 +88,7 @@ def __init__(
9288
project: Optional[str] = None,
9389
location: Optional[str] = None,
9490
credentials: Optional[Credentials] = None,
95-
):
91+
**kwargs):
9692
vertexai.init(project=project, location=location, credentials=credentials)
9793
self.model_name = model_name
9894
if task_type not in TASK_TYPE_INPUTS:
@@ -101,47 +97,16 @@ def __init__(
10197
self.task_type = task_type
10298
self.title = title
10399

104-
# Configure AdaptiveThrottler and throttling metrics for client-side
105-
# throttling behavior.
106-
# See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
107-
# for more details.
108-
self.throttled_secs = Metrics.counter(
109-
VertexAIImageEmbeddings, "cumulativeThrottlingSeconds")
110-
self.throttler = AdaptiveThrottler(
111-
window_ms=1, bucket_ms=1, overload_ratio=2)
112-
113-
@retry.with_exponential_backoff(
114-
num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
115-
def get_request(
116-
self,
117-
text_batch: Sequence[TextEmbeddingInput],
118-
model: TextEmbeddingModel,
119-
throttle_delay_secs: int):
120-
while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC):
121-
LOGGER.info(
122-
"Delaying request for %d seconds due to previous failures",
123-
throttle_delay_secs)
124-
time.sleep(throttle_delay_secs)
125-
self.throttled_secs.inc(throttle_delay_secs)
126-
127-
try:
128-
req_time = time.time()
129-
prediction = model.get_embeddings(list(text_batch))
130-
self.throttler.successful_request(req_time * _MSEC_TO_SEC)
131-
return prediction
132-
except TooManyRequests as e:
133-
LOGGER.warning("request was limited by the service with code %i", e.code)
134-
raise
135-
except Exception as e:
136-
LOGGER.error("unexpected exception raised as part of request, got %s", e)
137-
raise
138-
139-
def run_inference(
100+
super().__init__(
101+
namespace='VertexAITextEmbeddingHandler',
102+
retry_filter=_retry_on_appropriate_gcp_error,
103+
**kwargs)
104+
105+
def request(
140106
self,
141107
batch: Sequence[str],
142-
model: Any,
143-
inference_args: Optional[dict[str, Any]] = None,
144-
) -> Iterable:
108+
model: TextEmbeddingModel,
109+
inference_args: Optional[dict[str, Any]] = None):
145110
embeddings = []
146111
batch_size = _BATCH_SIZE
147112
for i in range(0, len(batch), batch_size):
@@ -151,12 +116,11 @@ def run_inference(
151116
text=text, title=self.title, task_type=self.task_type)
152117
for text in text_batch_strs
153118
]
154-
embeddings_batch = self.get_request(
155-
text_batch=text_batch, model=model, throttle_delay_secs=5)
119+
embeddings_batch = model.get_embeddings(list(text_batch))
156120
embeddings.extend([el.values for el in embeddings_batch])
157121
return embeddings
158122

159-
def load_model(self):
123+
def create_client(self) -> TextEmbeddingModel:
160124
model = TextEmbeddingModel.from_pretrained(self.model_name)
161125
return model
162126

@@ -205,6 +169,7 @@ def __init__(
205169
self.credentials = credentials
206170
self.title = title
207171
self.task_type = task_type
172+
self.kwargs = kwargs
208173
super().__init__(columns=columns, **kwargs)
209174

210175
def get_model_handler(self) -> ModelHandler:
@@ -215,76 +180,45 @@ def get_model_handler(self) -> ModelHandler:
215180
credentials=self.credentials,
216181
title=self.title,
217182
task_type=self.task_type,
218-
)
183+
**self.kwargs)
219184

220185
def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
221186
return RunInference(
222187
model_handler=_TextEmbeddingHandler(self),
223188
inference_args=self.inference_args)
224189

225190

226-
class _VertexAIImageEmbeddingHandler(ModelHandler):
191+
class _VertexAIImageEmbeddingHandler(RemoteModelHandler):
227192
def __init__(
228193
self,
229194
model_name: str,
230195
dimension: Optional[int] = None,
231196
project: Optional[str] = None,
232197
location: Optional[str] = None,
233198
credentials: Optional[Credentials] = None,
234-
):
199+
**kwargs):
235200
vertexai.init(project=project, location=location, credentials=credentials)
236201
self.model_name = model_name
237202
self.dimension = dimension
238203

239-
# Configure AdaptiveThrottler and throttling metrics for client-side
240-
# throttling behavior.
241-
# See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
242-
# for more details.
243-
self.throttled_secs = Metrics.counter(
244-
VertexAIImageEmbeddings, "cumulativeThrottlingSeconds")
245-
self.throttler = AdaptiveThrottler(
246-
window_ms=1, bucket_ms=1, overload_ratio=2)
247-
248-
@retry.with_exponential_backoff(
249-
num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
250-
def get_request(
251-
self,
252-
img: Image,
253-
model: MultiModalEmbeddingModel,
254-
throttle_delay_secs: int):
255-
while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC):
256-
LOGGER.info(
257-
"Delaying request for %d seconds due to previous failures",
258-
throttle_delay_secs)
259-
time.sleep(throttle_delay_secs)
260-
self.throttled_secs.inc(throttle_delay_secs)
261-
262-
try:
263-
req_time = time.time()
264-
prediction = model.get_embeddings(image=img, dimension=self.dimension)
265-
self.throttler.successful_request(req_time * _MSEC_TO_SEC)
266-
return prediction
267-
except TooManyRequests as e:
268-
LOGGER.warning("request was limited by the service with code %i", e.code)
269-
raise
270-
except Exception as e:
271-
LOGGER.error("unexpected exception raised as part of request, got %s", e)
272-
raise
273-
274-
def run_inference(
204+
super().__init__(
205+
namespace='VertexAIImageEmbeddingHandler',
206+
retry_filter=_retry_on_appropriate_gcp_error,
207+
**kwargs)
208+
209+
def request(
275210
self,
276-
batch: Sequence[Image],
211+
imgs: Sequence[Image],
277212
model: MultiModalEmbeddingModel,
278-
inference_args: Optional[dict[str, Any]] = None,
279-
) -> Iterable:
213+
inference_args: Optional[dict[str, Any]] = None):
280214
embeddings = []
281-
# Maximum request size for muli-model embedding models is 1.
282-
for img in batch:
283-
embedding_response = self.get_request(img, model, throttle_delay_secs=5)
284-
embeddings.append(embedding_response.image_embedding)
215+
# Max request size for multi-modal embedding models is 1
216+
for img in imgs:
217+
prediction = model.get_embeddings(image=img, dimension=self.dimension)
218+
embeddings.append(prediction.image_embedding)
285219
return embeddings
286220

287-
def load_model(self):
221+
def create_client(self):
288222
model = MultiModalEmbeddingModel.from_pretrained(self.model_name)
289223
return model
290224

@@ -327,6 +261,7 @@ def __init__(
327261
self.project = project
328262
self.location = location
329263
self.credentials = credentials
264+
self.kwargs = kwargs
330265
if dimension is not None and dimension not in (128, 256, 512, 1408):
331266
raise ValueError(
332267
"dimension argument must be one of 128, 256, 512, or 1408")
@@ -340,7 +275,7 @@ def get_model_handler(self) -> ModelHandler:
340275
project=self.project,
341276
location=self.location,
342277
credentials=self.credentials,
343-
)
278+
**self.kwargs)
344279

345280
def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
346281
return RunInference(

0 commit comments

Comments (0)