
Commit 64187d9

Refactoring
1 parent a22d694 commit 64187d9

4 files changed: +32 -34 lines

.github/workflows/beam_Inference_Python_Benchmarks_Dataflow.yml

Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ jobs:
             ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Batch_DistilBert_Base_Uncased.txt
             ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_VLLM_Gemma_Batch.txt
             ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Object_Detection.txt
-            ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Object_Captioning.txt
+            ${{ github.workspace }}/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Captioning.txt
       # The env variables are created and populated in the test-arguments-action as "<github.job>_test_arguments_<argument_file_paths_index>"
       - name: get current time
         run: echo "NOW_UTC=$(date '+%m%d%H%M%S' --utc)" >> $GITHUB_ENV

.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Object_Captioning.txt renamed to .github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Image_Captioning.txt

File renamed without changes.

sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py

Lines changed: 29 additions & 30 deletions

@@ -51,7 +51,6 @@
 import torch
 import PIL.Image as PILImage
 
-
 # ============ Utility ============
 

@@ -143,13 +142,13 @@ def process(self, kv: Tuple[str, PredictionResult]):

 class BlipCaptionModelHandler(ModelHandler):
   def __init__(
-      self,
-      model_name: str,
-      device: str,
-      batch_size: int,
-      num_captions: int,
-      max_new_tokens: int,
-      num_beams: int):
+      self,
+      model_name: str,
+      device: str,
+      batch_size: int,
+      num_captions: int,
+      max_new_tokens: int,
+      num_beams: int):
     self.model_name = model_name
     self.device = device
     self.batch_size = batch_size
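
The hunk above only re-wraps the constructor, but it records the BLIP handler's full signature. A hedged sketch of how it might be instantiated; only the parameter names (and the --num_beams default of 5 visible later in this diff) come from the commit, everything else is an assumed placeholder:

# Hypothetical usage; BlipCaptionModelHandler is defined in this example
# module, assumed importable from a checkout that contains it.
from apache_beam.examples.inference.pytorch_image_captioning import (
    BlipCaptionModelHandler)

blip_handler = BlipCaptionModelHandler(
    model_name='Salesforce/blip-image-captioning-base',  # placeholder model id
    device='GPU',          # placeholder; the diff only shows `device: str`
    batch_size=8,          # placeholder
    num_captions=5,        # placeholder
    max_new_tokens=20,     # placeholder
    num_beams=5)           # matches the --num_beams default in this diff
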
@@ -214,7 +213,7 @@ def run_inference(
     candidates_per_image = []
     idx = 0
     for _ in range(len(batch)):
-      candidates_per_image.append(captions_all[idx: idx + self.num_captions])
+      candidates_per_image.append(captions_all[idx:idx + self.num_captions])
       idx += self.num_captions
 
     blip_ms = now_millis() - start
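
Only the slice spacing changes here, but the loop is worth reading: it regroups the flat captions_all list into one list of num_captions candidates per input image. A standalone sketch of the same regrouping with made-up data:

# Illustrative only: mirrors the regrouping loop in run_inference.
captions_all = ['a dog', 'a brown dog', 'a puppy', 'a cat', 'a grey cat', 'a kitten']
num_captions = 3
batch = ['image_1', 'image_2']  # two inputs -> two groups of three captions

candidates_per_image = []
idx = 0
for _ in range(len(batch)):
  candidates_per_image.append(captions_all[idx:idx + num_captions])
  idx += num_captions

assert candidates_per_image == [
    ['a dog', 'a brown dog', 'a puppy'],
    ['a cat', 'a grey cat', 'a kitten'],
]
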
@@ -235,11 +234,11 @@ def get_metrics_namespace(self) -> str:

 class ClipRankModelHandler(ModelHandler):
   def __init__(
-      self,
-      model_name: str,
-      device: str,
-      batch_size: int,
-      score_normalize: bool):
+      self,
+      model_name: str,
+      device: str,
+      batch_size: int,
+      score_normalize: bool):
     self.model_name = model_name
     self.device = device
     self.batch_size = batch_size
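
As with the BLIP handler, this hunk is formatting-only but documents the CLIP ranker's constructor. A hedged instantiation sketch; the model name and batch size reuse the --clip_model_name and --clip_batch_size defaults visible later in this diff, the rest is assumed:

# Hypothetical usage of the CLIP ranking handler defined in this module.
from apache_beam.examples.inference.pytorch_image_captioning import (
    ClipRankModelHandler)

clip_handler = ClipRankModelHandler(
    model_name='openai/clip-vit-base-patch32',  # --clip_model_name default
    device='GPU',           # placeholder
    batch_size=8,           # --clip_batch_size default
    score_normalize=False)  # matches the 'false' default of --clip_score_normalize
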
@@ -363,7 +362,8 @@ def parse_known_args(argv):
   parser.add_argument('--num_beams', type=int, default=5)
 
   # CLIP
-  parser.add_argument('--clip_model_name', default='openai/clip-vit-base-patch32')
+  parser.add_argument(
+      '--clip_model_name', default='openai/clip-vit-base-patch32')
   parser.add_argument('--clip_batch_size', type=int, default=8)
   parser.add_argument(
       '--clip_score_normalize', default='false', choices=['true', 'false'])
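
Note that --clip_score_normalize is a string flag constrained to 'true'/'false' rather than a boolean, the same convention the pipeline applies to --publish_to_big_query further down. A small self-contained sketch of how such a flag is typically turned into a bool; the exact conversion site is not part of this diff:

import argparse

# Illustrative: the flag arrives as a string, so an equality check against
# 'true' yields the boolean the handler expects.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--clip_score_normalize', default='false', choices=['true', 'false'])

known_args, _ = parser.parse_known_args(['--clip_score_normalize', 'true'])
score_normalize = known_args.clip_score_normalize == 'true'
assert score_normalize is True
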
@@ -415,9 +415,7 @@ def run(
         | 'MakeKey' >> beam.ParDo(MakeKeyDoFn())
         | 'DistinctByKey' >> beam.Distinct())
 
-    images = (
-        keyed
-        | 'ReadImageBytes' >> beam.ParDo(ReadImageBytesDoFn()))
+    images = (keyed | 'ReadImageBytes' >> beam.ParDo(ReadImageBytesDoFn()))
 
     # Stage 1: BLIP candidate generation
     blip_out = (
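
The reflowed images step sits inside the MakeKey -> Distinct -> ReadImageBytes chain shown in context. A self-contained sketch of that de-duplication pattern, with stand-in data and transforms in place of the example's DoFns:

# Illustrative only: duplicate inputs are dropped by beam.Distinct() before
# the (stand-in) read step, mirroring 'DistinctByKey' | 'ReadImageBytes'.
import apache_beam as beam

with beam.Pipeline() as p:
  images = (
      p
      | 'Create' >> beam.Create(
          ['gs://bucket/a.jpg', 'gs://bucket/a.jpg', 'gs://bucket/b.jpg'])
      | 'DistinctByKey' >> beam.Distinct()
      | 'ReadImageBytes' >> beam.Map(lambda path: (path, b'<image bytes>')))
  _ = images | 'Print' >> beam.Map(print)
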
@@ -431,23 +429,24 @@ def run(
 
     results = (
         clip_out
-        | 'PostProcess' >> beam.ParDo(PostProcessDoFn(
-            blip_name=known_args.blip_model_name,
-            clip_name=known_args.clip_model_name)))
+        | 'PostProcess' >> beam.ParDo(
+            PostProcessDoFn(
+                blip_name=known_args.blip_model_name,
+                clip_name=known_args.clip_model_name)))
 
     if known_args.publish_to_big_query == 'true':
       _ = (
           results
           | 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
-              known_args.output_table,
-              schema=(
-                  'image_id:STRING, blip_model:STRING, clip_model:STRING, '
-                  'best_caption:STRING, best_score:FLOAT, '
-                  'candidates:STRING, scores:STRING, '
-                  'blip_ms:INT64, clip_ms:INT64, total_ms:INT64, infer_ms:INT64'),
-              write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
-              create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
-              method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
+              known_args.output_table,
+              schema=(
+                  'image_id:STRING, blip_model:STRING, clip_model:STRING, '
+                  'best_caption:STRING, best_score:FLOAT, '
+                  'candidates:STRING, scores:STRING, '
+                  'blip_ms:INT64, clip_ms:INT64, total_ms:INT64, infer_ms:INT64'),
+              write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
+              create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
+              method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
 
     result = pipeline.run()
     result.wait_until_finish(duration=1800000)  # 30 min
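
The schema string above fixes the row layout for the BigQuery sink. A hypothetical row matching that schema, i.e. the dict shape the upstream PostProcess step would have to emit; the field names come from the diff, the values are invented:

# Hypothetical example row; only the keys/types follow the schema in the diff.
row = {
    'image_id': 'img_0001',
    'blip_model': 'Salesforce/blip-image-captioning-base',  # placeholder id
    'clip_model': 'openai/clip-vit-base-patch32',
    'best_caption': 'a dog running on the beach',
    'best_score': 0.31,
    'candidates': '["a dog running on the beach", "a dog on wet sand"]',
    'scores': '[0.31, 0.27]',
    'blip_ms': 812,
    'clip_ms': 95,
    'total_ms': 934,
    'infer_ms': 907,
}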

sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py

Lines changed: 2 additions & 3 deletions

@@ -155,7 +155,7 @@ def _torchvision_detection_inference_fn(
 class PostProcessDoFn(beam.DoFn):
   """PredictionResult -> dict row for BQ."""
   def __init__(
-      self, model_name: str, score_threshold: float, max_detections: int):
+      self, model_name: str, score_threshold: float, max_detections: int):
     self.model_name = model_name
     self.score_threshold = score_threshold
     self.max_detections = max_detections
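
This hunk only re-wraps the constructor, but the two values it stores suggest the usual post-filtering: keep detections scoring at or above score_threshold, capped at max_detections. That filtering code is outside this diff, so the sketch below is purely illustrative:

# Purely illustrative; the actual filtering logic is not shown in this commit.
detections = [
    {'label': 'dog', 'score': 0.92},
    {'label': 'cat', 'score': 0.55},
    {'label': 'kite', 'score': 0.12},
]
score_threshold = 0.5
max_detections = 2

kept = [d for d in detections if d['score'] >= score_threshold][:max_detections]
assert kept == [
    {'label': 'dog', 'score': 0.92},
    {'label': 'cat', 'score': 0.55},
]
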
@@ -329,8 +329,7 @@ def run(
 
   model_handler = PytorchModelHandlerTensor(
       model_class=lambda: create_torchvision_detection_model(
-          known_args.pretrained_model_name
-      ),
+          known_args.pretrained_model_name),
       model_params={},
       state_dict_path=known_args.model_state_dict_path,
       device=device,
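
The folded closing paren belongs to a lambda that exists because PytorchModelHandlerTensor takes a zero-argument model_class; the lambda binds known_args.pretrained_model_name into that callable. A hedged standalone sketch of the same pattern, with a placeholder factory and paths standing in for the example's own definitions:

# Sketch of the zero-arg model_class pattern; the factory, model name and
# paths are placeholders, not the example's actual definitions.
import torchvision
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

def create_torchvision_detection_model(name):
  # Placeholder factory: build an uninitialized torchvision detection model
  # (weights are loaded later from state_dict_path); torchvision >= 0.13 API.
  return getattr(torchvision.models.detection, name)(
      weights=None, weights_backbone=None)

model_handler = PytorchModelHandlerTensor(
    model_class=lambda: create_torchvision_detection_model(
        'fasterrcnn_resnet50_fpn'),
    model_params={},
    state_dict_path='gs://my-bucket/models/fasterrcnn_resnet50_fpn.pth',
    device='GPU')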
