Skip to content

Commit f64c764

Browse files
HannaMaofacebook-github-bot
authored andcommitted
Make benchmark tool works for py config files for eval task
Summary: Make benchmark tool works for py config files for eval task Reviewed By: ppwwyyxx Differential Revision: D31289374 fbshipit-source-id: 0d758cf3fb7aeccc18eb383e4236bd4203289ef4
1 parent 2df9c7d commit f64c764

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

tools/benchmark.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,23 @@ def f():
133133
@torch.no_grad()
134134
def benchmark_eval(args):
135135
cfg = setup(args)
136-
model = build_model(cfg)
136+
if args.config_file.endswith(".yaml"):
137+
model = build_model(cfg)
138+
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
139+
140+
cfg.defrost()
141+
cfg.DATALOADER.NUM_WORKERS = 0
142+
data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
143+
else:
144+
model = instantiate(cfg.model)
145+
model.to(cfg.train.device)
146+
DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
147+
148+
cfg.dataloader.num_workers = 0
149+
data_loader = instantiate(cfg.dataloader.test)
150+
137151
model.eval()
138152
logger.info("Model:\n{}".format(model))
139-
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
140-
141-
cfg.defrost()
142-
cfg.DATALOADER.NUM_WORKERS = 0
143-
data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
144153
dummy_data = DatasetFromList(list(itertools.islice(data_loader, 100)), copy=False)
145154

146155
def f():

0 commit comments

Comments
 (0)