-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathinfer.py
More file actions
61 lines (42 loc) · 1.51 KB
/
infer.py
File metadata and controls
61 lines (42 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
from multiprocessing import Pool
from pathlib import Path
from pickle import dumps
from mmengine import Config
from mmengine.device import get_device
from mmengine.runner import Runner, load_checkpoint
from more_itertools import flatten
from tqdm import tqdm
from mutab.score import TEDS
from mutab.utils import build
def options():
    """Parse the command-line options of the inference script.

    All four flags (``--config``, ``--weight``, ``--split``, ``--store``)
    are required string values.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    # Every option shares the same declaration, so register them in a loop.
    for flag in ("--config", "--weight", "--split", "--store"):
        parser.add_argument(flag, type=str, required=True)
    return parser.parse_args()
def conduct(model, loader: dict, split: str):
    """Run ``model`` over one dataset split and score the predictions.

    Args:
        model: Object exposing ``eval()`` and ``test_step(batch)``; each
            ``test_step`` call must return an iterable of per-sample outputs
            (the batches are flattened into one sequence).
        loader: Dataloader config. Its ``dataset.filter_cfg`` is updated in
            place with ``split`` before the dataloader is built, so it must
            support attribute access (presumably an mmengine ``ConfigDict``).
        split: Split name written into the dataset filter.

    Returns:
        A dict with ``scores`` (the value of ``TEDS.compute_metrics``) and
        ``data`` (tuple of per-sample results returned by ``TEDS._teds``).
    """
    loader.dataset.filter_cfg.update(split=split)

    # The pool is still created before inference starts, but the context
    # manager now guarantees the worker processes are shut down, including
    # when inference or scoring raises. Previously the pool was never closed.
    with Pool() as pool:
        model = model.eval()
        metric = TEDS(prefix="full")
        batches = tqdm(Runner.build_dataloader(loader))
        # NOTE(review): test_step is invoked directly rather than through a
        # Runner loop -- confirm gradient tracking is disabled elsewhere.
        outputs = tuple(flatten(map(model.test_step, batches)))
        # Per-sample scoring is fanned out across the worker processes.
        result = tuple(pool.map(metric._teds, outputs))

    scores = metric.compute_metrics(result)
    return dict(scores=scores, data=result)
def process(config: str, weight: str, split: str, store: str):
    """Load a model, run inference on one split and pickle the outcome.

    Args:
        config: Path of the config file read with ``Config.fromfile``.
        weight: Checkpoint passed to ``load_checkpoint`` (non-strict load).
        split: Dataset split forwarded to ``conduct``.
        store: Destination file for the pickled result; ``~`` is expanded
            and missing parent directories are created.
    """
    # config
    cfg = Config.fromfile(config)

    # model
    model = build(cfg.model).to(get_device())
    load_checkpoint(model, weight, strict=False)

    # Prepare the destination before the expensive inference run so that a
    # missing or uncreatable output directory fails early instead of
    # discarding the finished results.
    path = Path(store).expanduser()
    path.parent.mkdir(parents=True, exist_ok=True)

    # infer
    result = conduct(model, cfg.test_dataloader, split)

    # store
    path.write_bytes(dumps(result))
if __name__ == "__main__":
    # Script entry point: forward every parsed CLI option to process()
    # as a keyword argument of the same name.
    cli = options()
    process(**vars(cli))