
Commit 4ce2406

Fix lint
1 parent 417a66f commit 4ce2406

File tree

1 file changed: +29 -21 lines changed


sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py

Lines changed: 29 additions & 21 deletions
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 """This streaming pipeline performs image classification using an open-source
-PyTorch EfficientNet-B0 model optimized for T4 GPUs. It reads image URIs from Pub/Sub,
-decodes and preprocesses them in parallel, and runs inference with adaptive batch sizing for optimal GPU utilization.
-The pipeline ensures exactly-once semantics via stateful deduplication and idempotent BigQuery writes,
-allowing stable and reproducible performance measurements under continuous load.
+PyTorch EfficientNet-B0 model optimized for T4 GPUs.
+It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel,
+and runs inference with adaptive batch sizing for optimal GPU utilization.
+The pipeline ensures exactly-once semantics via stateful deduplication and
+idempotent BigQuery writes, allowing stable and reproducible performance
+measurements under continuous load.
 Resources like Pub/Sub topic/subscription cleanup is handled programmatically.
 """
 
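The rewrapped docstring keeps the pipeline's core claim intact: stable per-image keys, stateful deduplication, and idempotent writes give exactly-once results under redelivery. A minimal, self-contained sketch of that flow, runnable on the DirectRunner without any GCP resources; every name below is illustrative, not part of the example file:

import hashlib

import apache_beam as beam
from apache_beam.coders import BytesCoder
from apache_beam.transforms import userstate


class Dedup(beam.DoFn):
  """Drop repeat deliveries of a key via per-key state (toy version)."""
  SEEN = userstate.ReadModifyWriteStateSpec('seen', BytesCoder())

  def process(self, kv, seen=beam.DoFn.StateParam(SEEN)):
    if seen.read() != b'1':
      seen.write(b'1')
      yield kv


with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([b'img-a', b'img-a', b'img-b'])  # duplicated delivery
      | beam.Map(lambda b: (hashlib.sha1(b).hexdigest(), b))  # stable key
      | beam.ParDo(Dedup())
      | beam.Map(lambda kv: {'image_id': kv[0], 'topk': '[]'})  # stand-in inference
      | beam.Map(print))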
@@ -36,18 +38,20 @@
 
 import apache_beam as beam
 from apache_beam.coders import BytesCoder
+from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import KeyedModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.ml.inference.base import RunInference
 from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
-from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions, StandardOptions
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.runners.runner import PipelineResult
 from apache_beam.transforms import userstate
 from apache_beam.transforms import window
 
-import PIL.Image as PILImage
 from google.cloud import pubsub_v1
-from apache_beam.io.filesystems import FileSystems
+import PIL.Image as PILImage
 
 # ============ Utility & Preprocessing ============
 
@@ -102,14 +106,15 @@ def process(self, element):
 
 
 class MakeKeyDoFn(beam.DoFn):
-  """Produce (image_id, payload) where image_id is stable for dedup & BQ insertId."""
+  """Produce (image_id, payload) stable for dedup & BQ insertId."""
   def __init__(self, input_mode: str):
     self.input_mode = input_mode
 
   def process(self, element: str | bytes):
-    # Input can be raw bytes from Pub/Sub or a GCS URI string, depending on mode.
+    # Input can be raw bytes from Pub/Sub or a GCS URI string, depends on mode
     if self.input_mode == "bytes":
-      # element is bytes message, assume it includes {"image_id": "...", "bytes": base64?} or just raw bytes.
+      # element is bytes message, assume it includes
+      # {"image_id": "...", "bytes": base64?} or just raw bytes.
       import hashlib
       b = element if isinstance(element, (bytes, bytearray)) else bytes(element)
       image_id = hashlib.sha1(b).hexdigest()
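The sha1-of-payload key above is what makes redelivery harmless downstream: the same bytes always map to the same image_id. A tiny runnable check (the payload value is illustrative):

import hashlib

payload = b'raw image bytes as delivered by Pub/Sub'
image_id = hashlib.sha1(payload).hexdigest()
# A redelivered copy of the message hashes to the identical id, so the
# stateful dedup step and BigQuery insert ids see it as the same row.
assert image_id == hashlib.sha1(payload).hexdigest()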
@@ -127,7 +132,6 @@ class DedupDoFn(beam.DoFn):
   seen = userstate.ReadModifyWriteStateSpec('seen', BytesCoder())
 
   def process(self, element, seen=beam.DoFn.StateParam(seen)):
-    key, payload = element
     if seen.read() == b'1':
       return
     seen.write(b'1')
@@ -176,7 +180,9 @@ def process(self, kv: Tuple[str, PredictionResult]):
       logits = logits.unsqueeze(0)
 
     probs = F.softmax(logits, dim=-1)  # [B, C]
-    values, indices = torch.topk(probs, k=min(self.top_k, probs.shape[-1]), dim=-1)
+    values, indices = torch.topk(
+        probs, k=min(self.top_k, probs.shape[-1]), dim=-1
+    )
 
     topk = [{
         "class_id": int(idx.item()), "score": float(val.item())
@@ -334,7 +340,7 @@ def pick_batch_size(arg: str) -> Optional[int]:
 
 
 def run_load_pipeline(known_args, pipeline_args):
-  """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming mode)."""
+  """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming)."""
   # enforce smaller/CPU-only defaults for feeder
   override_or_add(pipeline_args, '--device', 'CPU')
   override_or_add(pipeline_args, '--num_workers', '5')
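override_or_add is defined elsewhere in this file; a hedged sketch of what such a helper typically does, written here only to make the feeder defaults readable (the real implementation may differ):

def override_or_add(args, flag, value):
  """Replace an existing `--flag value` or `--flag=value`, else append."""
  for i, arg in enumerate(args):
    if arg == flag and i + 1 < len(args):
      args[i + 1] = value  # separate-token form: --device CPU
      return
    if arg.startswith(flag + '='):
      args[i] = f'{flag}={value}'  # single-token form: --device=CPU
      return
  args.extend([flag, value])


args = ['--num_workers', '50']
override_or_add(args, '--device', 'CPU')
override_or_add(args, '--num_workers', '5')
print(args)  # ['--num_workers', '5', '--device', 'CPU']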
@@ -362,7 +368,10 @@ def run_load_pipeline(known_args, pipeline_args):
         lines
         | 'ToBytes' >> beam.Map(lambda line: line.encode('utf-8'))
         |
-        'PublishToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic))
+        'PublishToPubSub' >> beam.io.WriteToPubSub(
+            topic=known_args.pubsub_topic
+        )
+    )
   return pipeline.run()
 
 
@@ -378,8 +387,8 @@ def run(
       topic_path=known_args.pubsub_topic,
       subscription_path=known_args.pubsub_subscription)
 
-  # If streaming -> start feeder thread that reads URIs from GCS and fills Pub/Sub.
   if known_args.mode == 'streaming':
+    # Start feeder thread that reads URIs from GCS and fills Pub/Sub.
     threading.Thread(
         target=lambda:
         (time.sleep(900), run_load_pipeline(known_args, pipeline_args)),
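The feeder runs on a background thread after a fixed warm-up delay (900 s here) so the streaming job is already consuming before load starts. A minimal runnable sketch of that pattern; the daemon flag is an assumption, not shown in this hunk:

import threading
import time


def start_feeder(run_feeder, delay_s=900):
  """Kick off a load-generating callable after a warm-up delay."""
  thread = threading.Thread(
      target=lambda: (time.sleep(delay_s), run_feeder()),
      daemon=True)  # assumption: do not block process exit on the feeder
  thread.start()
  return thread


start_feeder(lambda: print('feeder pipeline would run here'), delay_s=1)
time.sleep(2)  # give the demo thread time to fire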
@@ -398,7 +407,6 @@ def run(
   # Device
   device = 'GPU' if known_args.device.upper() == 'GPU' else 'CPU'
 
-  model = None
   bs_ok = None
   last_err = None
   for bs in tried:
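The loop that follows (mostly outside this hunk) probes candidate batch sizes and keeps the first that works, falling back to 8 as the next hunk shows. A hedged, self-contained sketch of that pattern, where probe_fn stands in for the real test inference:

def pick_working_batch_size(candidates, probe_fn, fallback=8):
  """Return the first batch size whose probe succeeds, else the fallback."""
  last_err = None
  for bs in candidates:
    try:
      probe_fn(bs)  # e.g. one dummy forward pass at this batch size
      return bs
    except RuntimeError as err:  # e.g. CUDA out-of-memory
      last_err = err
  print('Falling back to batch_size=%s due to previous errors: %s'
        % (fallback, last_err))
  return fallback


def fake_probe(bs):
  if bs > 16:
    raise RuntimeError('CUDA out of memory')


print(pick_working_batch_size([64, 32, 16], fake_probe))  # -> 16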
@@ -431,15 +439,15 @@ def run(
         "Falling back to batch_size=8 due to previous errors: %s", last_err)
     bs_ok = 8
   model_handler = PytorchModelHandlerTensor(
-      model_class=lambda: create_timm_model(known_args.pretrained_model_name),
+      model_class=lambda: create_timm_model(
+          known_args.pretrained_model_name
+      ),
       model_params={},
       state_dict_path=known_args.model_state_dict_path,
       device=device,
       inference_batch_size=bs_ok,
   )
 
-  tokenizer = None
-
   pipeline = test_pipeline or beam.Pipeline(options=pipeline_options)
 
   if known_args.mode == 'batch':
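For context, a handler like the one above is normally wrapped in a KeyedModelHandler so the image_id keys pass through inference untouched. A hedged sketch with a placeholder model and path; inference_batch_size is specific to this example file and omitted here:

import torch
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

# Stand-ins for create_timm_model(...) and the real EfficientNet weights.
handler = PytorchModelHandlerTensor(
    state_dict_path='gs://my-bucket/efficientnet_b0.pth',  # placeholder path
    model_class=torch.nn.Linear,
    model_params={'in_features': 8, 'out_features': 2},
    device='CPU')
keyed_handler = KeyedModelHandler(handler)
# ... | 'Infer' >> RunInference(keyed_handler)  # (key, Tensor) in, (key, PredictionResult) out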
@@ -491,13 +499,13 @@ def run(
               model_name=known_args.pretrained_model_name)))
 
   if known_args.publish_to_big_query == 'true':
-    # Schema: image_id:STRING, model_name:STRING, topk:STRING(JSON), infer_ts_ms:INT64
     _ = (
         results
         | 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
            known_args.output_table,
            schema=
-            'image_id:STRING, model_name:STRING, topk:STRING, infer_ts_ms:INT64',
+            'image_id:STRING, model_name:STRING, topk:STRING, '
+            'infer_ts_ms:INT64',
            write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
            create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
            method=beam.io.WriteToBigQuery.Method.STREAMING_INSERTS))
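Python concatenates the split schema string back into the original 'field:TYPE, ...' form, so the sink is unchanged. A hedged sketch of this sink as a reusable function, with a placeholder table name:

import apache_beam as beam


def write_results(results, table='my-project:my_dataset.predictions'):
  """Append prediction rows to BigQuery via streaming inserts."""
  return results | 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
      table,  # placeholder; the example takes this from --output_table
      schema='image_id:STRING, model_name:STRING, topk:STRING, '
      'infer_ts_ms:INT64',
      write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
      create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
      method=beam.io.WriteToBigQuery.Method.STREAMING_INSERTS)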
