 import torch
 import PIL.Image as PILImage

-
 # ============ Utility ============


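The utility block itself is cut off by the hunk, but `now_millis()` is used for the stage timings further down (`blip_ms = now_millis() - start`). A minimal sketch of what such a helper could look like; the actual definition is not shown in this diff:

    import time


    def now_millis() -> int:
      # Integer milliseconds since the epoch, so the difference of two
      # calls gives elapsed wall-clock time in ms.
      return int(time.time() * 1000)
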
@@ -143,13 +142,13 @@ def process(self, kv: Tuple[str, PredictionResult]):

 class BlipCaptionModelHandler(ModelHandler):
   def __init__(
-    self,
-    model_name: str,
-    device: str,
-    batch_size: int,
-    num_captions: int,
-    max_new_tokens: int,
-    num_beams: int):
+      self,
+      model_name: str,
+      device: str,
+      batch_size: int,
+      num_captions: int,
+      max_new_tokens: int,
+      num_beams: int):
     self.model_name = model_name
     self.device = device
     self.batch_size = batch_size
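For context, handlers like this plug into Beam's RunInference transform, typically wrapped in a KeyedModelHandler because the pipeline below carries keyed elements. A hypothetical wiring sketch; every constructor value except num_beams=5 (the parser default below) is invented:

    from apache_beam.ml.inference.base import KeyedModelHandler, RunInference

    blip_handler = BlipCaptionModelHandler(
        model_name='Salesforce/blip-image-captioning-base',  # assumed value
        device='cpu',
        batch_size=4,
        num_captions=3,
        max_new_tokens=20,
        num_beams=5)
    # images | 'BlipCaptions' >> RunInference(KeyedModelHandler(blip_handler))
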
@@ -214,7 +213,7 @@ def run_inference(
     candidates_per_image = []
     idx = 0
     for _ in range(len(batch)):
-      candidates_per_image.append(captions_all[idx : idx + self.num_captions])
+      candidates_per_image.append(captions_all[idx:idx + self.num_captions])
       idx += self.num_captions

     blip_ms = now_millis() - start
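The slice change is cosmetic, but the loop itself is the per-image regrouping step: `captions_all` holds num_captions entries per batched image in one flat list, and this loop cuts it back into one candidate list per image. The same regrouping in isolation:

    def group_candidates(captions_all, num_images, num_captions):
      # Cut a flat list of num_images * num_captions strings into
      # consecutive chunks of num_captions, one chunk per image.
      return [
          captions_all[i * num_captions:(i + 1) * num_captions]
          for i in range(num_images)
      ]


    assert group_candidates(['a1', 'a2', 'b1', 'b2'], 2, 2) == [
        ['a1', 'a2'], ['b1', 'b2']]
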
@@ -235,11 +234,11 @@ def get_metrics_namespace(self) -> str:

 class ClipRankModelHandler(ModelHandler):
   def __init__(
-    self,
-    model_name: str,
-    device: str,
-    batch_size: int,
-    score_normalize: bool):
+      self,
+      model_name: str,
+      device: str,
+      batch_size: int,
+      score_normalize: bool):
     self.model_name = model_name
     self.device = device
     self.batch_size = batch_size
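The second stage ranks the BLIP candidates by CLIP image-text similarity; score_normalize presumably softmaxes the similarities into a distribution over the candidates. A standalone sketch of that ranking idea with Hugging Face transformers, not the handler's actual code:

    import torch
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    _model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
    _processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')


    def rank_captions(image: Image.Image, captions, score_normalize=False):
      # One forward pass scores every caption against the single image.
      inputs = _processor(
          text=captions, images=image, return_tensors='pt', padding=True)
      with torch.no_grad():
        logits = _model(**inputs).logits_per_image[0]  # one score per caption
      scores = logits.softmax(dim=-1) if score_normalize else logits
      order = scores.argsort(descending=True)
      return [(captions[i], float(scores[i])) for i in order.tolist()]
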
@@ -363,7 +362,8 @@ def parse_known_args(argv):
   parser.add_argument('--num_beams', type=int, default=5)

   # CLIP
-  parser.add_argument('--clip_model_name', default='openai/clip-vit-base-patch32')
+  parser.add_argument(
+      '--clip_model_name', default='openai/clip-vit-base-patch32')
   parser.add_argument('--clip_batch_size', type=int, default=8)
   parser.add_argument(
       '--clip_score_normalize', default='false', choices=['true', 'false'])
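Note that the boolean-ish flags are plain strings constrained by choices=['true', 'false'], so downstream code compares against the string 'true' (as the publish_to_big_query check below does). A hypothetical parse, assuming parse_known_args returns argparse's usual (known, unknown) pair:

    known_args, pipeline_args = parse_known_args([
        '--clip_model_name=openai/clip-vit-base-patch32',
        '--clip_batch_size=4',  # value invented for illustration
        '--clip_score_normalize=true',
    ])
    assert known_args.clip_score_normalize == 'true'  # a string, not a bool
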
@@ -415,9 +415,7 @@ def run(
       | 'MakeKey' >> beam.ParDo(MakeKeyDoFn())
       | 'DistinctByKey' >> beam.Distinct())

-  images = (
-      keyed
-      | 'ReadImageBytes' >> beam.ParDo(ReadImageBytesDoFn()))
+  images = (keyed | 'ReadImageBytes' >> beam.ParDo(ReadImageBytesDoFn()))

   # Stage 1: BLIP candidate generation
   blip_out = (
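MakeKeyDoFn, ReadImageBytesDoFn, and the RunInference wiring behind blip_out are defined outside these hunks. A hypothetical shape for the read step, assuming elements are (key, file path) pairs; the real DoFn may differ:

    import apache_beam as beam
    from apache_beam.io.filesystems import FileSystems


    class ReadImageBytesDoFn(beam.DoFn):
      def process(self, kv):
        key, path = kv
        # FileSystems handles local paths and gs:// URIs alike.
        with FileSystems.open(path) as f:
          yield key, f.read()
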
@@ -431,23 +429,24 @@ def run(

   results = (
       clip_out
-      | 'PostProcess' >> beam.ParDo(PostProcessDoFn(
-          blip_name=known_args.blip_model_name,
-          clip_name=known_args.clip_model_name)))
+      | 'PostProcess' >> beam.ParDo(
+          PostProcessDoFn(
+              blip_name=known_args.blip_model_name,
+              clip_name=known_args.clip_model_name)))

   if known_args.publish_to_big_query == 'true':
     _ = (
         results
         | 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
-          known_args.output_table,
-          schema=(
-              'image_id:STRING, blip_model:STRING, clip_model:STRING, '
-              'best_caption:STRING, best_score:FLOAT, '
-              'candidates:STRING, scores:STRING, '
-              'blip_ms:INT64, clip_ms:INT64, total_ms:INT64, infer_ms:INT64'),
-          write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
-          create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
-          method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
+            known_args.output_table,
+            schema=(
+                'image_id:STRING, blip_model:STRING, clip_model:STRING, '
+                'best_caption:STRING, best_score:FLOAT, '
+                'candidates:STRING, scores:STRING, '
+                'blip_ms:INT64, clip_ms:INT64, total_ms:INT64, infer_ms:INT64'),
+            write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
+            create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
+            method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
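The schema implies PostProcessDoFn emits one flat dict per image keyed by these column names, with the candidate list and scores serialized to strings. A hypothetical row; all values invented for illustration:

    example_row = {
        'image_id': 'img_0001',
        'blip_model': 'Salesforce/blip-image-captioning-base',
        'clip_model': 'openai/clip-vit-base-patch32',
        'best_caption': 'a dog running on the beach',
        'best_score': 0.41,
        'candidates': '["a dog running on the beach", "a dog on sand"]',
        'scores': '[0.41, 0.33]',
        'blip_ms': 820,
        'clip_ms': 140,
        'total_ms': 975,
        'infer_ms': 960,
    }
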

   result = pipeline.run()
   result.wait_until_finish(duration=1800000)  # 30 min