2020# to install Vertex AI Python SDK.
2121
2222import logging
23- import time
24- from collections .abc import Iterable
2523from collections .abc import Sequence
2624from typing import Any
2725from typing import Optional
3230
3331import apache_beam as beam
3432import vertexai
35- from apache_beam .io .components .adaptive_throttler import AdaptiveThrottler
36- from apache_beam .metrics .metric import Metrics
3733from apache_beam .ml .inference .base import ModelHandler
34+ from apache_beam .ml .inference .base import RemoteModelHandler
3835from apache_beam .ml .inference .base import RunInference
3936from apache_beam .ml .transforms .base import EmbeddingsManager
4037from apache_beam .ml .transforms .base import _ImageEmbeddingHandler
4138from apache_beam .ml .transforms .base import _TextEmbeddingHandler
42- from apache_beam .utils import retry
4339from vertexai .language_models import TextEmbeddingInput
4440from vertexai .language_models import TextEmbeddingModel
4541from vertexai .vision_models import Image
def _retry_on_appropriate_gcp_error(exception):
  """Retry filter for Vertex AI requests.

  Returns True only for TooManyRequests or ServerError exceptions, i.e.
  the transient GCP failures (rate limiting and server-side errors) that
  are worth retrying; all other exceptions propagate immediately.

  Args:
    exception: the exception raised by the remote call.

  Returns:
    bool: whether the failed request should be retried.
  """
  return isinstance(exception, (TooManyRequests, ServerError))
8177
8278
83- class _VertexAITextEmbeddingHandler (ModelHandler ):
79+ class _VertexAITextEmbeddingHandler (RemoteModelHandler ):
8480 """
8581 Note: Intended for internal use and guarantees no backwards compatibility.
8682 """
@@ -92,7 +88,7 @@ def __init__(
9288 project : Optional [str ] = None ,
9389 location : Optional [str ] = None ,
9490 credentials : Optional [Credentials ] = None ,
95- ):
91+ ** kwargs ):
9692 vertexai .init (project = project , location = location , credentials = credentials )
9793 self .model_name = model_name
9894 if task_type not in TASK_TYPE_INPUTS :
@@ -101,47 +97,16 @@ def __init__(
10197 self .task_type = task_type
10298 self .title = title
10399
104- # Configure AdaptiveThrottler and throttling metrics for client-side
105- # throttling behavior.
106- # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing
107- # for more details.
108- self .throttled_secs = Metrics .counter (
109- VertexAIImageEmbeddings , "cumulativeThrottlingSeconds" )
110- self .throttler = AdaptiveThrottler (
111- window_ms = 1 , bucket_ms = 1 , overload_ratio = 2 )
112-
113- @retry .with_exponential_backoff (
114- num_retries = 5 , retry_filter = _retry_on_appropriate_gcp_error )
115- def get_request (
116- self ,
117- text_batch : Sequence [TextEmbeddingInput ],
118- model : TextEmbeddingModel ,
119- throttle_delay_secs : int ):
120- while self .throttler .throttle_request (time .time () * _MSEC_TO_SEC ):
121- LOGGER .info (
122- "Delaying request for %d seconds due to previous failures" ,
123- throttle_delay_secs )
124- time .sleep (throttle_delay_secs )
125- self .throttled_secs .inc (throttle_delay_secs )
126-
127- try :
128- req_time = time .time ()
129- prediction = model .get_embeddings (list (text_batch ))
130- self .throttler .successful_request (req_time * _MSEC_TO_SEC )
131- return prediction
132- except TooManyRequests as e :
133- LOGGER .warning ("request was limited by the service with code %i" , e .code )
134- raise
135- except Exception as e :
136- LOGGER .error ("unexpected exception raised as part of request, got %s" , e )
137- raise
138-
139- def run_inference (
100+ super ().__init__ (
101+ namespace = 'VertexAITextEmbeddingHandler' ,
102+ retry_filter = _retry_on_appropriate_gcp_error ,
103+ ** kwargs )
104+
105+ def request (
140106 self ,
141107 batch : Sequence [str ],
142- model : Any ,
143- inference_args : Optional [dict [str , Any ]] = None ,
144- ) -> Iterable :
108+ model : TextEmbeddingModel ,
109+ inference_args : Optional [dict [str , Any ]] = None ):
145110 embeddings = []
146111 batch_size = _BATCH_SIZE
147112 for i in range (0 , len (batch ), batch_size ):
@@ -151,12 +116,11 @@ def run_inference(
151116 text = text , title = self .title , task_type = self .task_type )
152117 for text in text_batch_strs
153118 ]
154- embeddings_batch = self .get_request (
155- text_batch = text_batch , model = model , throttle_delay_secs = 5 )
119+ embeddings_batch = model .get_embeddings (list (text_batch ))
156120 embeddings .extend ([el .values for el in embeddings_batch ])
157121 return embeddings
158122
159- def load_model (self ):
123+ def create_client (self ) -> TextEmbeddingModel :
160124 model = TextEmbeddingModel .from_pretrained (self .model_name )
161125 return model
162126
@@ -205,6 +169,7 @@ def __init__(
205169 self .credentials = credentials
206170 self .title = title
207171 self .task_type = task_type
172+ self .kwargs = kwargs
208173 super ().__init__ (columns = columns , ** kwargs )
209174
210175 def get_model_handler (self ) -> ModelHandler :
@@ -215,76 +180,45 @@ def get_model_handler(self) -> ModelHandler:
215180 credentials = self .credentials ,
216181 title = self .title ,
217182 task_type = self .task_type ,
218- )
183+ ** self . kwargs )
219184
220185 def get_ptransform_for_processing (self , ** kwargs ) -> beam .PTransform :
221186 return RunInference (
222187 model_handler = _TextEmbeddingHandler (self ),
223188 inference_args = self .inference_args )
224189
225190
class _VertexAIImageEmbeddingHandler(RemoteModelHandler):
  """Embeds images with a Vertex AI multi-modal embedding model.

  Retry behavior for remote calls is delegated to the RemoteModelHandler
  base class via the ``retry_filter`` passed to ``super().__init__()``.

  Note: Intended for internal use and guarantees no backwards compatibility.
  """
  def __init__(
      self,
      model_name: str,
      dimension: Optional[int] = None,
      project: Optional[str] = None,
      location: Optional[str] = None,
      credentials: Optional[Credentials] = None,
      **kwargs):
    """
    Args:
      model_name: name of the pretrained multi-modal embedding model.
      dimension: optional output embedding dimension, forwarded to
        ``get_embeddings()``.
      project: GCP project used to initialize the Vertex AI client.
      location: GCP location used to initialize the Vertex AI client.
      credentials: optional credentials for the Vertex AI client.
      **kwargs: forwarded to the RemoteModelHandler base class.
    """
    vertexai.init(project=project, location=location, credentials=credentials)
    self.model_name = model_name
    self.dimension = dimension

    super().__init__(
        namespace='VertexAIImageEmbeddingHandler',
        retry_filter=_retry_on_appropriate_gcp_error,
        **kwargs)

  def request(
      self,
      imgs: Sequence[Image],
      model: MultiModalEmbeddingModel,
      inference_args: Optional[dict[str, Any]] = None):
    """Returns one image embedding per input image, in input order."""
    embeddings = []
    # Max request size for multi-modal embedding models is 1, so issue
    # one request per image rather than batching.
    for img in imgs:
      prediction = model.get_embeddings(image=img, dimension=self.dimension)
      embeddings.append(prediction.image_embedding)
    return embeddings

  def create_client(self):
    """Loads the pretrained multi-modal embedding model named by
    ``self.model_name`` and returns it as the client used by ``request()``."""
    model = MultiModalEmbeddingModel.from_pretrained(self.model_name)
    return model
290224
@@ -327,6 +261,7 @@ def __init__(
327261 self .project = project
328262 self .location = location
329263 self .credentials = credentials
264+ self .kwargs = kwargs
330265 if dimension is not None and dimension not in (128 , 256 , 512 , 1408 ):
331266 raise ValueError (
332267 "dimension argument must be one of 128, 256, 512, or 1408" )
@@ -340,7 +275,7 @@ def get_model_handler(self) -> ModelHandler:
340275 project = self .project ,
341276 location = self .location ,
342277 credentials = self .credentials ,
343- )
278+ ** self . kwargs )
344279
345280 def get_ptransform_for_processing (self , ** kwargs ) -> beam .PTransform :
346281 return RunInference (
0 commit comments