-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
52 lines (43 loc) · 1.92 KB
/
eval.py
File metadata and controls
52 lines (43 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import hydra
import torch
from omegaconf import DictConfig
from src.data import build_dataloaders
from src.modules import GaussianLogLikelihoodLoss
from src.models import MLPRegressor
from src.utils import expected_calibration_error, get_device, mae, nll, rmse, set_seed
def _build_model(cfg: DictConfig, device: torch.device) -> MLPRegressor:
    """Construct the MLPRegressor described by ``cfg.model`` and move it to *device*."""
    return MLPRegressor(
        input_dim=cfg.model.get("input_dim", 1),
        hidden_sizes=cfg.model.hidden_sizes,
        activation=cfg.model.activation,
        dropout=cfg.model.get("dropout", 0.0),
    ).to(device)


def _load_checkpoint(model: torch.nn.Module, path: str, device: torch.device) -> None:
    """Load the state dict stored under the ``"model"`` key of the file at *path* into *model*."""
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
    # If every supported torch version accepts it, consider passing weights_only=True.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model"])


def _evaluate(model: torch.nn.Module, loader, criterion, device: torch.device) -> tuple:
    """Run *model* over *loader* without gradients and aggregate test metrics.

    Returns:
        ``(metrics, mean_loss)``. ``metrics`` maps "mse", "mae", "rmse", "nll"
        and "ece" to sample-weighted averages over the whole loader.
        ``mean_loss`` is the sample-weighted mean of *criterion*, or NaN when
        the loader yields no samples.
    """
    model.eval()
    # Running sums of (per-batch value * batch size). Dividing by the sample count
    # later turns them into sample-weighted means.
    sums = {"mse": 0.0, "mae": 0.0, "nll": 0.0, "ece": 0.0}
    total_loss = 0.0
    count = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            mean, variance = model(data)
            loss = criterion(mean, target, variance=variance, interpolate=False)
            n = len(data)
            total_loss += loss.item() * n
            # NOTE(review): assumes mean and target have the same shape; a (B, 1)
            # vs (B,) mismatch would broadcast to (B, B) here. Confirm against the model.
            sums["mse"] += torch.mean((mean - target) ** 2).item() * n
            sums["mae"] += mae(mean, target) * n
            sums["nll"] += nll(mean, target, variance) * n
            # NOTE(review): averaging per-batch ECE only approximates the ECE of the
            # full test set, because binning is non-linear.
            sums["ece"] += expected_calibration_error(mean, variance, target) * n
            count += n

    denom = max(count, 1)  # keep metrics at 0.0 (not a ZeroDivisionError) on an empty loader
    mse_value = sums["mse"] / denom
    metrics = {
        "mse": mse_value,
        "mae": sums["mae"] / denom,
        # RMSE must be sqrt of the dataset MSE. The weighted mean of per-batch RMSEs
        # is biased low whenever there is more than one batch.
        "rmse": mse_value ** 0.5,
        "nll": sums["nll"] / denom,
        "ece": sums["ece"] / denom,
    }
    mean_loss = total_loss / count if count else float("nan")
    return metrics, mean_loss


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Evaluate an MLPRegressor on the test split and print aggregate metrics.

    Seeds RNGs, builds the test dataloader and model from *cfg*, and optionally
    loads ``cfg.checkpoint``. It then prints the mean criterion loss (with the
    criterion's ``beta``) and a dict of mse/mae/rmse/nll/ece averaged over all
    test samples.
    """
    set_seed(cfg.seed)  # before model construction, so the weight init is reproducible
    device = get_device(cfg.get("device", "auto"))
    _, _, test_loader = build_dataloaders(cfg.hyperparameters.batch_size, seed=cfg.seed)
    model = _build_model(cfg, device)
    if cfg.get("checkpoint"):
        _load_checkpoint(model, cfg.checkpoint, device)
    else:
        print("Warning: no checkpoint provided; evaluating randomly initialised model.")
    criterion = GaussianLogLikelihoodLoss()
    metrics, mean_loss = _evaluate(model, test_loader, criterion, device)
    print(f"Loss: {mean_loss:.4f} beta={criterion.beta:.4f}")
    print("Test metrics:", metrics)


if __name__ == "__main__":
    # main() takes no arguments here: the @hydra.main decorator supplies cfg.
    main()