1414# limitations under the License.
1515
1616"""This streaming pipeline performs image classification using an open-source
17- PyTorch EfficientNet-B0 model optimized for T4 GPUs. It reads image URIs from Pub/Sub,
18- decodes and preprocesses them in parallel, and runs inference with adaptive batch sizing for optimal GPU utilization.
19- The pipeline ensures exactly-once semantics via stateful deduplication and idempotent BigQuery writes,
20- allowing stable and reproducible performance measurements under continuous load.
17+ PyTorch EfficientNet-B0 model optimized for T4 GPUs.
18+ It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel,
19+ and runs inference with adaptive batch sizing for optimal GPU utilization.
20+ The pipeline ensures exactly-once semantics via stateful deduplication and
21+ idempotent BigQuery writes, allowing stable and reproducible performance
22+ measurements under continuous load.
2123Resources like Pub/Sub topic/subscription cleanup is handled programmatically.
2224"""
2325
3638
3739import apache_beam as beam
3840from apache_beam .coders import BytesCoder
41+ from apache_beam .io .filesystems import FileSystems
3942from apache_beam .ml .inference .base import KeyedModelHandler
4043from apache_beam .ml .inference .base import PredictionResult
4144from apache_beam .ml .inference .base import RunInference
4245from apache_beam .ml .inference .pytorch_inference import PytorchModelHandlerTensor
43- from apache_beam .options .pipeline_options import PipelineOptions , SetupOptions , StandardOptions
46+ from apache_beam .options .pipeline_options import PipelineOptions
47+ from apache_beam .options .pipeline_options import SetupOptions
48+ from apache_beam .options .pipeline_options import StandardOptions
4449from apache_beam .runners .runner import PipelineResult
4550from apache_beam .transforms import userstate
4651from apache_beam .transforms import window
4752
48- import PIL .Image as PILImage
4953from google .cloud import pubsub_v1
50- from apache_beam . io . filesystems import FileSystems
54+ import PIL . Image as PILImage
5155
5256# ============ Utility & Preprocessing ============
5357
@@ -102,14 +106,15 @@ def process(self, element):
102106
103107
104108class MakeKeyDoFn (beam .DoFn ):
105- """Produce (image_id, payload) where image_id is stable for dedup & BQ insertId."""
109+ """Produce (image_id, payload) stable for dedup & BQ insertId."""
106110 def __init__ (self , input_mode : str ):
107111 self .input_mode = input_mode
108112
109113 def process (self , element : str | bytes ):
110- # Input can be raw bytes from Pub/Sub or a GCS URI string, depending on mode.
114+ # Input can be raw bytes from Pub/Sub or a GCS URI string, depends on mode
111115 if self .input_mode == "bytes" :
112- # element is bytes message, assume it includes {"image_id": "...", "bytes": base64?} or just raw bytes.
116+ # element is bytes message, assume it includes
117+ # {"image_id": "...", "bytes": base64?} or just raw bytes.
113118 import hashlib
114119 b = element if isinstance (element , (bytes , bytearray )) else bytes (element )
115120 image_id = hashlib .sha1 (b ).hexdigest ()
@@ -127,7 +132,6 @@ class DedupDoFn(beam.DoFn):
127132 seen = userstate .ReadModifyWriteStateSpec ('seen' , BytesCoder ())
128133
129134 def process (self , element , seen = beam .DoFn .StateParam (seen )):
130- key , payload = element
131135 if seen .read () == b'1' :
132136 return
133137 seen .write (b'1' )
@@ -176,7 +180,9 @@ def process(self, kv: Tuple[str, PredictionResult]):
176180 logits = logits .unsqueeze (0 )
177181
178182 probs = F .softmax (logits , dim = - 1 ) # [B, C]
179- values , indices = torch .topk (probs , k = min (self .top_k , probs .shape [- 1 ]), dim = - 1 )
183+ values , indices = torch .topk (
184+ probs , k = min (self .top_k , probs .shape [- 1 ]), dim = - 1
185+ )
180186
181187 topk = [{
182188 "class_id" : int (idx .item ()), "score" : float (val .item ())
@@ -334,7 +340,7 @@ def pick_batch_size(arg: str) -> Optional[int]:
334340
335341
336342def run_load_pipeline (known_args , pipeline_args ):
337- """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming mode )."""
343+ """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming)."""
338344 # enforce smaller/CPU-only defaults for feeder
339345 override_or_add (pipeline_args , '--device' , 'CPU' )
340346 override_or_add (pipeline_args , '--num_workers' , '5' )
@@ -362,7 +368,10 @@ def run_load_pipeline(known_args, pipeline_args):
362368 lines
363369 | 'ToBytes' >> beam .Map (lambda line : line .encode ('utf-8' ))
364370 |
365- 'PublishToPubSub' >> beam .io .WriteToPubSub (topic = known_args .pubsub_topic ))
371+ 'PublishToPubSub' >> beam .io .WriteToPubSub (
372+ topic = known_args .pubsub_topic
373+ )
374+ )
366375 return pipeline .run ()
367376
368377
@@ -378,8 +387,8 @@ def run(
378387 topic_path = known_args .pubsub_topic ,
379388 subscription_path = known_args .pubsub_subscription )
380389
381- # If streaming -> start feeder thread that reads URIs from GCS and fills Pub/Sub.
382390 if known_args .mode == 'streaming' :
391+ # Start feeder thread that reads URIs from GCS and fills Pub/Sub.
383392 threading .Thread (
384393 target = lambda :
385394 (time .sleep (900 ), run_load_pipeline (known_args , pipeline_args )),
@@ -398,7 +407,6 @@ def run(
398407 # Device
399408 device = 'GPU' if known_args .device .upper () == 'GPU' else 'CPU'
400409
401- model = None
402410 bs_ok = None
403411 last_err = None
404412 for bs in tried :
@@ -431,15 +439,15 @@ def run(
431439 "Falling back to batch_size=8 due to previous errors: %s" , last_err )
432440 bs_ok = 8
433441 model_handler = PytorchModelHandlerTensor (
434- model_class = lambda : create_timm_model (known_args .pretrained_model_name ),
442+ model_class = lambda : create_timm_model (
443+ known_args .pretrained_model_name
444+ ),
435445 model_params = {},
436446 state_dict_path = known_args .model_state_dict_path ,
437447 device = device ,
438448 inference_batch_size = bs_ok ,
439449 )
440450
441- tokenizer = None
442-
443451 pipeline = test_pipeline or beam .Pipeline (options = pipeline_options )
444452
445453 if known_args .mode == 'batch' :
@@ -491,13 +499,13 @@ def run(
491499 model_name = known_args .pretrained_model_name )))
492500
493501 if known_args .publish_to_big_query == 'true' :
494- # Schema: image_id:STRING, model_name:STRING, topk:STRING(JSON), infer_ts_ms:INT64
495502 _ = (
496503 results
497504 | 'WriteToBigQuery' >> beam .io .WriteToBigQuery (
498505 known_args .output_table ,
499506 schema =
500- 'image_id:STRING, model_name:STRING, topk:STRING, infer_ts_ms:INT64' ,
507+ 'image_id:STRING, model_name:STRING, topk:STRING, '
508+ 'infer_ts_ms:INT64' ,
501509 write_disposition = beam .io .BigQueryDisposition .WRITE_APPEND ,
502510 create_disposition = beam .io .BigQueryDisposition .CREATE_IF_NEEDED ,
503511 method = beam .io .WriteToBigQuery .Method .STREAMING_INSERTS ))
0 commit comments