@@ -192,7 +192,7 @@ def process(self, kv: Tuple[str, PredictionResult]):
192192 "image_id" : image_id ,
193193 "model_name" : self .model_name ,
194194 "topk" : json .dumps (topk ),
195- "infer_ts_ms " : now_millis (),
195+ "infer_ms " : now_millis (),
196196 }
197197
198198
@@ -319,7 +319,7 @@ def override_or_add(args, flag, value):
319319# ============ Model factory (timm) ============
320320
321321
322- def create_timm_model (model_name : str , num_classes : int = 1000 ):
322+ def create_timm_m (model_name : str , num_classes : int = 1000 ):
323323 import timm
324324 model = timm .create_model (
325325 model_name , pretrained = True , num_classes = num_classes )
@@ -367,8 +367,7 @@ def run_load_pipeline(known_args, pipeline_args):
367367 _ = (
368368 lines
369369 | 'ToBytes' >> beam .Map (lambda line : line .encode ('utf-8' ))
370- |
371- 'WriteToPubSub' >> beam .io .WriteToPubSub (topic = known_args .pubsub_topic ))
370+ | 'ToPubSub' >> beam .io .WriteToPubSub (topic = known_args .pubsub_topic ))
372371 return pipeline .run ()
373372
374373
@@ -409,7 +408,7 @@ def run(
409408 for bs in tried :
410409 try :
411410 model_handler = PytorchModelHandlerTensor (
412- model_class = lambda : create_timm_model (known_args .pretrained_model_name ),
411+ model_class = lambda : create_timm_m (known_args .pretrained_model_name ),
413412 model_params = {},
414413 state_dict_path = known_args .model_state_dict_path ,
415414 device = device ,
@@ -435,9 +434,7 @@ def run(
435434 "Falling back to batch_size=8 due to previous errors: %s" , last_err )
436435 bs_ok = 8
437436 model_handler = PytorchModelHandlerTensor (
438- model_class = lambda : create_timm_model (
439- known_args .pretrained_model_name
440- ),
437+ model_class = lambda : create_timm_m (known_args .pretrained_model_name ),
441438 model_params = {},
442439 state_dict_path = known_args .model_state_dict_path ,
443440 device = device ,
@@ -500,7 +497,7 @@ def run(
500497 | 'WriteToBigQuery' >> beam .io .WriteToBigQuery (
501498 known_args .output_table ,
502499 schema =
503- 'image_id:STRING, model_name:STRING, topk:STRING, infer_ts_ms :INT64' ,
500+ 'image_id:STRING, model_name:STRING, topk:STRING, infer_ms :INT64' ,
504501 write_disposition = beam .io .BigQueryDisposition .WRITE_APPEND ,
505502 create_disposition = beam .io .BigQueryDisposition .CREATE_IF_NEEDED ,
506503 method = beam .io .WriteToBigQuery .Method .STREAMING_INSERTS ))
0 commit comments