Skip to content

Commit 2aaaea0

Browse files
committed
Fix lint
1 parent cf6f968 commit 2aaaea0

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def process(self, kv: Tuple[str, PredictionResult]):
192192
"image_id": image_id,
193193
"model_name": self.model_name,
194194
"topk": json.dumps(topk),
195-
"infer_ts_ms": now_millis(),
195+
"infer_ms": now_millis(),
196196
}
197197

198198

@@ -319,7 +319,7 @@ def override_or_add(args, flag, value):
319319
# ============ Model factory (timm) ============
320320

321321

322-
def create_timm_model(model_name: str, num_classes: int = 1000):
322+
def create_timm_m(model_name: str, num_classes: int = 1000):
323323
import timm
324324
model = timm.create_model(
325325
model_name, pretrained=True, num_classes=num_classes)
@@ -367,8 +367,7 @@ def run_load_pipeline(known_args, pipeline_args):
367367
_ = (
368368
lines
369369
| 'ToBytes' >> beam.Map(lambda line: line.encode('utf-8'))
370-
|
371-
'WriteToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic))
370+
| 'ToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic))
372371
return pipeline.run()
373372

374373

@@ -409,7 +408,7 @@ def run(
409408
for bs in tried:
410409
try:
411410
model_handler = PytorchModelHandlerTensor(
412-
model_class=lambda: create_timm_model(known_args.pretrained_model_name),
411+
model_class=lambda: create_timm_m(known_args.pretrained_model_name),
413412
model_params={},
414413
state_dict_path=known_args.model_state_dict_path,
415414
device=device,
@@ -435,9 +434,7 @@ def run(
435434
"Falling back to batch_size=8 due to previous errors: %s", last_err)
436435
bs_ok = 8
437436
model_handler = PytorchModelHandlerTensor(
438-
model_class=lambda: create_timm_model(
439-
known_args.pretrained_model_name
440-
),
437+
model_class=lambda: create_timm_m(known_args.pretrained_model_name),
441438
model_params={},
442439
state_dict_path=known_args.model_state_dict_path,
443440
device=device,
@@ -500,7 +497,7 @@ def run(
500497
| 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
501498
known_args.output_table,
502499
schema=
503-
'image_id:STRING, model_name:STRING, topk:STRING, infer_ts_ms:INT64',
500+
'image_id:STRING, model_name:STRING, topk:STRING, infer_ms:INT64',
504501
write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
505502
create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
506503
method=beam.io.WriteToBigQuery.Method.STREAMING_INSERTS))

0 commit comments

Comments
 (0)