Skip to content

Commit d68aae4

Browse files
Merge pull request #5 from RadarML/dev
Evaluation Pipeline
2 parents d464d2d + 5f5f9e7 commit d68aae4

File tree

15 files changed

+4090
-766
lines changed

15 files changed

+4090
-766
lines changed

grt/config/default.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ datamodule:
3636
train:
3737
_target_: nrdk.config.expand
3838
path: ${meta.dataset}
39+
test:
40+
_target_: nrdk.config.expand
41+
path: ${meta.dataset}
3942
batch_size: 32
4043
samples: 8
4144
num_workers: 16

grt/evaluate.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""GRT reference implementation evaluation script."""
2+
3+
import os
4+
import re
5+
from functools import partial
6+
from queue import Queue
7+
from typing import Any, Callable
8+
9+
import hydra
10+
import numpy as np
11+
import torch
12+
import tyro
13+
from omegaconf import DictConfig
14+
from roverd.channels.utils import Prefetch
15+
from roverd.sensors import DynamicSensor
16+
from tqdm import tqdm
17+
18+
from nrdk.framework import Result
19+
20+
21+
def _get_dataloaders(
    cfg: DictConfig, data_root: str, transforms: Any,
    traces: list[str] | None = None, filter: str | None = None,
    sample: int | None = None
) -> dict[str, Callable[[], torch.utils.data.DataLoader]]:
    """Build named dataloader constructors for evaluation.

    Returns a mapping from trace name (or the literal key ``"sample"``) to a
    zero-argument callable that constructs the corresponding test dataloader.
    Constructors are lazy so only one dataloader is alive at a time.
    """
    datamodule = hydra.utils.instantiate(
        cfg["datamodule"], transforms=transforms)

    # Pure sample mode: neither explicit traces nor a filter were given, so a
    # single dataloader over the (subsampled) test split is sufficient.
    if traces is None and filter is None and sample is not None:
        return {"sample": lambda: datamodule.test_dataloader()}

    dataset_constructor = hydra.utils.instantiate(
        cfg["datamodule"]["dataset"])
    if traces is None:
        # Fall back to the test traces listed in the config, expressed
        # relative to the configured dataset root.
        dataset_root = cfg["meta"]["dataset"]
        traces = [
            os.path.relpath(t, dataset_root)
            for t in hydra.utils.instantiate(
                cfg["datamodule"]["traces"]["test"])]
    if filter is not None:
        traces = [t for t in traces if re.match(filter, t)]

    def construct(trace_path: str) -> torch.utils.data.DataLoader:
        # One single-trace dataset per dataloader, in test mode.
        dataset = dataset_constructor(paths=[trace_path])
        return datamodule.dataloader(dataset, mode="test")

    return {
        t: partial(construct, os.path.join(data_root, t)) for t in traces}
49+
50+
def evaluate(
    path: str, /, sample: int | None = None,
    traces: list[str] | None = None, filter: str | None = None,
    data_root: str | None = None,
    device: str = "cuda:0",
    batch: int = 32, workers: int = 32, prefetch: int = 2
) -> None:
    """Evaluate a trained model.

    Supports three evaluation modes, in order of precedence:

    1. Enumerated traces: evaluate all traces specified by `--trace`, relative
       to the `--data-root`.
    2. Filtered evaluation: evaluate all traces in the configuration
       (`datamodule/traces/test`) that match the provided `--filter` regex.
    3. Sample evaluation: evaluate a pseudo-random `--sample` taken from
       the test set specified in the configuration.

    If none of `--trace`, `--filter`, or `--sample` are provided, defaults to
    evaluating all traces specified in the configuration.

    !!! tip

        See [`Result`][nrdk.framework.Result] for details about the expected
        structure of the results directory.

    !!! warning

        Only supports using a single GPU; if multiple GPUs are available,
        use parallel evaluation instead.

    Args:
        path: path to results directory.
        sample: number of samples to evaluate.
        traces: explicit list of traces to evaluate.
        filter: evaluate all traces matching this regex.
        data_root: root dataset directory; if `None`, use the path specified
            in `meta/dataset` in the config.
        device: device to use for evaluation.
        batch: batch size.
        workers: number of workers for data loading.
        prefetch: number of batches to prefetch per worker.

    Raises:
        ValueError: if neither `--data_root` nor `meta/dataset` is set.
    """
    result = Result(path)
    cfg = result.config()
    if sample is not None:
        cfg["datamodule"]["subsample"]["test"] = sample

    if data_root is None:
        data_root = cfg["meta"]["dataset"]
        if data_root is None:
            raise ValueError(
                "`--data_root` must be specified if `meta/dataset` is not set "
                "in the config.")
    else:
        cfg["meta"]["dataset"] = data_root

    # Override the training-time loader settings with evaluation settings.
    cfg["datamodule"]["batch_size"] = batch
    cfg["datamodule"]["num_workers"] = workers
    cfg["datamodule"]["prefetch_factor"] = prefetch

    transforms = hydra.utils.instantiate(cfg["transforms"])
    lightningmodule = hydra.utils.instantiate(
        cfg["lightningmodule"], transforms=transforms).to(device)
    lightningmodule.load_weights(result.best)

    dataloaders = _get_dataloaders(
        cfg, data_root, transforms,
        traces=traces, filter=filter, sample=sample)

    def collect_metadata(y_true):
        # Reference timestamps for each ground-truth modality.
        return {
            f"meta/{k}/ts": getattr(v, "timestamps")
            for k, v in y_true.items()
        }

    for trace, dl_constructor in dataloaders.items():
        dataloader = dl_constructor()
        eval_stream = tqdm(
            Prefetch(lightningmodule.evaluate(
                dataloader, metadata=collect_metadata, device=device)),
            total=len(dataloader), desc=trace)

        output_container = DynamicSensor(
            os.path.join(result.path, "eval", trace),
            create=True, exist_ok=True)
        metrics = []
        outputs = {}
        for batch_metrics, vis in eval_stream:
            if len(outputs) == 0:
                # Lazily create one output channel per visualization key on
                # the first batch; each channel is drained from its queue by
                # a writer thread (consume(..., thread=True)).
                for k, v in vis.items():
                    outputs[k] = Queue()
                    output_container.create(
                        k.split("/")[-1], meta={
                            "format": "lzmaf",
                            "type": f"{v.dtype.kind}{v.dtype.itemsize}",
                            "shape": v.shape[1:],
                            "desc": f"eval_render:{k}"
                        }
                    ).consume(outputs[k], thread=True)

            for k, v in vis.items():
                # Fix: loop variable renamed from `sample`, which shadowed
                # the `sample` argument of this function.
                for frame in v:
                    outputs[k].put(frame)
            metrics.append(batch_metrics)

        # Sentinel: tell each writer thread that the stream is finished.
        for q in outputs.values():
            q.put(None)

        if not metrics:
            # Fix: an empty eval stream previously crashed on `metrics[0]`;
            # skip writing outputs for a trace that produced no batches.
            continue

        metrics = {
            k: np.concatenate([m[k] for m in metrics], axis=0)
            for k in metrics[0]}
        # Fix: np.savez_compressed has no `allow_pickle` parameter — extra
        # keyword arguments are treated as arrays to save, so the previous
        # `allow_pickle=False` silently wrote a spurious boolean array named
        # "allow_pickle" into metrics.npz.
        np.savez_compressed(
            os.path.join(result.path, "eval", trace, "metrics.npz"),
            **metrics)

        # NOTE(review): assumes a "spectrum" modality is always present in
        # the collected metadata — confirm against the model's y_true keys.
        output_container.create("ts", meta={
            "format": "raw", "type": "f8", "shape": (),
            "desc": "reference timestamps"}
        ).write(metrics["meta/spectrum/ts"])
172+
if __name__ == "__main__":
    # Expose `evaluate` as a command-line interface via tyro: function
    # arguments become CLI flags (e.g. `--sample`, `--filter`).
    tyro.cli(evaluate)

grt/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""GRT Reference implementation training script."""
1+
"""GRT reference implementation training script."""
22

33
import logging
44
import os

0 commit comments

Comments
 (0)