Skip to content

Commit 86b10af

Browse files
committed
feature!: refactor metrics
BREAKING CHANGE: `lantern.Metrics` is replaced by `lantern.MetricTable` plus plain dicts of `Metric` objects, `lantern.worker_init` is renamed to `lantern.worker_init_fn`, and `lantern.ProgressBar` now takes a required `name` argument.
1 parent cc84004 commit 86b10af

File tree

10 files changed

+144
-141
lines changed

10 files changed

+144
-141
lines changed

docs/source/lantern.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ lantern package
1313
:members:
1414
:undoc-members:
1515
:member-order: bysource
16-
.. autoclass:: lantern.Metrics
16+
.. autoclass:: lantern.MetricTable
1717
:members:
1818
:undoc-members:
1919
:member-order: bysource
@@ -29,6 +29,10 @@ lantern package
2929
:members:
3030
:undoc-members:
3131
:member-order: bysource
32+
.. autoclass:: lantern.Tensor
33+
:members:
34+
:undoc-members:
35+
:member-order: bysource
3236

3337
.. autofunction:: lantern.Epochs
3438
.. autofunction:: lantern.module_device
@@ -40,4 +44,5 @@ lantern package
4044
.. autofunction:: lantern.update_cpu_model
4145
.. autofunction:: lantern.step
4246
.. autofunction:: lantern.to_device
43-
.. autofunction:: lantern.worker_init
47+
.. autofunction:: lantern.worker_init_fn
48+
.. autofunction:: lantern.set_learning_rate

lantern/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from lantern.functional_base import FunctionalBase
77
from lantern.tensor import Tensor
88
from lantern.epochs import Epochs
9-
from lantern.metric import ReduceMetric, MapMetric
10-
from lantern.metrics import Metrics
9+
from lantern.metric import Metric, ReduceMetric, MapMetric
10+
from lantern.metric_table import MetricTable
1111
from lantern.module_device import module_device
1212
from lantern.module_compose import ModuleCompose
1313
from lantern.to_device import to_device
@@ -16,7 +16,7 @@
1616
from lantern.requires_grad import requires_grad, requires_nograd
1717
from lantern.set_learning_rate import set_learning_rate
1818
from lantern.set_seeds import set_seeds
19-
from lantern.worker_init import worker_init
19+
from lantern.worker_init_fn import worker_init_fn
2020
from lantern.evaluate import evaluate
2121
from lantern.step import step
2222
from lantern.train import train

lantern/early_stopping.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ class EarlyStopping(FunctionalBase):
1313

1414
class Config:
1515
arbitrary_types_allowed = True
16-
allow_mutation = False
1716

1817
def score(self, value):
1918
if self.best_score is None or value >= self.best_score:

lantern/metric.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,59 @@
11
import numpy as np
2+
from abc import ABC, abstractmethod
3+
from lantern import FunctionalBase
4+
from typing import Callable, Any, Optional
25

36

4-
class ReduceMetric:
7+
class Metric(ABC):
8+
@abstractmethod
9+
def update(self):
10+
...
11+
12+
@abstractmethod
13+
def update_(self):
14+
...
15+
16+
@abstractmethod
17+
def compute(self):
18+
...
19+
20+
21+
class ReduceMetric(FunctionalBase, Metric):
22+
reduce_fn: Callable
23+
compute_fn: Callable
24+
state: Optional[Any]
25+
26+
class Config:
27+
allow_mutation = True
28+
529
def __init__(self, reduce_fn, compute_fn=None, initial_state=None):
6-
self.reduce_fn = reduce_fn
7-
if compute_fn is None:
8-
self.compute_fn = lambda x: x
9-
else:
10-
self.compute_fn = compute_fn
11-
self.state = initial_state
12-
13-
def reduce(self, *args, **kwargs):
14-
return ReduceMetric(
15-
reduce_fn=self.reduce_fn,
16-
compute_fn=self.compute_fn,
17-
initial_state=self.reduce_fn(self.state, *args, **kwargs),
30+
super().__init__(
31+
reduce_fn=reduce_fn,
32+
compute_fn=((lambda x: x) if compute_fn is None else compute_fn),
33+
state=initial_state,
1834
)
1935

36+
def update(self, *args, **kwargs):
37+
return self.replace(state=self.reduce_fn(self.state, *args, **kwargs))
38+
39+
def update_(self, *args, **kwargs):
40+
self.state = self.reduce_fn(self.state, *args, **kwargs)
41+
return self
42+
2043
def compute(self):
2144
return self.compute_fn(self.state)
2245

46+
def log(self, tensorboard_logger, tag, name, step=1):
47+
tensorboard_logger.add_scalar(
48+
f"{tag}/{name}",
49+
self.compute(),
50+
step,
51+
)
52+
return self
53+
2354

2455
def MapMetric(map_fn, compute_fn=np.mean):
56+
"""Metric version of `compute_fn(map(map_fn, input))`"""
2557
return ReduceMetric(
2658
reduce_fn=lambda state, *args, **kwargs: state + [map_fn(*args, **kwargs)],
2759
compute_fn=compute_fn,

lantern/metric_table.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import textwrap
2+
import pandas as pd
3+
from lantern import FunctionalBase, Metric
4+
from typing import Dict
5+
6+
7+
class MetricTable(FunctionalBase):
8+
name: str
9+
metrics: Dict[str, Metric]
10+
11+
class Config:
12+
arbitrary_types_allowed = True
13+
14+
def __init__(self, name, metrics):
15+
super().__init__(
16+
name=name,
17+
metrics=metrics,
18+
)
19+
20+
def compute(self):
21+
return {name: metric.compute() for name, metric in self.metrics.items()}
22+
23+
def table(self):
24+
return "\n".join(
25+
[
26+
f"{self.name}:",
27+
textwrap.indent(
28+
(
29+
pd.Series(self.compute()).to_string(
30+
name=True, dtype=False, index=True
31+
)
32+
),
33+
prefix=" ",
34+
),
35+
]
36+
)
37+
38+
def __str__(self):
39+
return self.table()

lantern/metrics.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

lantern/progress_bar.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from tqdm import tqdm
2+
from typing import Dict, Optional
3+
from lantern.metric import Metric
24

35

4-
def ProgressBar(data_loader, metrics=None):
6+
def ProgressBar(data_loader, name, metrics: Optional[Dict[str, Metric]] = None):
57
"""Simple progress bar with metrics"""
68
if metrics is None:
7-
return tqdm(data_loader, leave=False)
9+
for item in tqdm(data_loader, desc=name, leave=False):
10+
yield item
811
else:
9-
with tqdm(data_loader, desc=metrics.name, leave=False) as tqdm_:
12+
with tqdm(data_loader, desc=name, leave=False) as tqdm_:
1013
for item in tqdm_:
1114
yield item
12-
tqdm_.set_postfix(metrics.compute())
15+
tqdm_.set_postfix(
16+
{name: metric.compute() for name, metric in metrics.items()}
17+
)

lantern/worker_init.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

lantern/worker_init_fn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from lantern import set_seeds
2+
3+
4+
def worker_init_fn(seed):
5+
def worker_init(worker_id):
6+
set_seeds(seed * 2 ** 16 + worker_id * 2 ** 24)
7+
8+
return worker_init

test/test_mnist.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
def test_mnist():
17+
torch.set_grad_enabled(False)
1718

1819
device = torch.device("cpu")
1920
model = ModuleCompose(
@@ -31,74 +32,71 @@ def test_mnist():
3132
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
3233
)
3334

34-
gradient_dataset = datastream.Dataset.from_subscriptable(
35+
train_dataset = datastream.Dataset.from_subscriptable(
3536
datasets.MNIST("data", train=True, transform=transform, download=True)
3637
)
3738
early_stopping_dataset = datastream.Dataset.from_subscriptable(
3839
datasets.MNIST("data", train=False, transform=transform)
3940
)
4041

41-
gradient_data_loader = (
42-
datastream.Datastream(gradient_dataset).take(16 * 4).data_loader(batch_size=16)
42+
train_data_loader = (
43+
datastream.Datastream(train_dataset).take(16 * 4).data_loader(batch_size=16)
4344
)
4445
early_stopping_data_loader = (
4546
datastream.Datastream(early_stopping_dataset)
4647
.take(16 * 4)
4748
.data_loader(batch_size=16)
4849
)
4950
evaluate_data_loaders = dict(
50-
evaluate_gradient=gradient_data_loader,
51+
evaluate_train=train_data_loader,
5152
evaluate_early_stopping=early_stopping_data_loader,
5253
)
5354

5455
tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
5556
early_stopping = lantern.EarlyStopping(tensorboard_logger=tensorboard_logger)
56-
gradient_metrics = lantern.Metrics(
57-
name="gradient",
58-
tensorboard_logger=tensorboard_logger,
59-
metrics=dict(
60-
loss=lantern.ReduceMetric(lambda state, examples, predictions, loss: loss),
61-
),
57+
train_metrics = dict(
58+
loss=lantern.ReduceMetric(lambda state, loss: loss.item()),
6259
)
6360

6461
for epoch in lantern.Epochs(2):
6562

66-
with lantern.module_train(model):
67-
for examples, targets in lantern.ProgressBar(
68-
gradient_data_loader, metrics=gradient_metrics[["loss"]]
69-
):
63+
for examples, targets in lantern.ProgressBar(
64+
train_data_loader, "train", train_metrics
65+
):
66+
with lantern.module_train(model), torch.enable_grad():
7067
predictions = model(examples)
7168
loss = F.nll_loss(predictions, targets)
7269
loss.backward()
73-
optimizer.step()
74-
optimizer.zero_grad()
70+
optimizer.step()
71+
optimizer.zero_grad()
7572

76-
(
77-
gradient_metrics.update_(
78-
examples, predictions.detach(), loss.detach()
79-
).log_()
80-
)
81-
sleep(0.5)
82-
gradient_metrics.print()
73+
train_metrics["loss"].update_(loss)
74+
sleep(0.5)
75+
76+
for name, metric in train_metrics.items():
77+
metric.log(tensorboard_logger, "train", name, epoch)
78+
79+
print(lantern.MetricTable("train", train_metrics))
8380

8481
evaluate_metrics = {
85-
name: lantern.Metrics(
86-
name=name,
87-
tensorboard_logger=tensorboard_logger,
88-
metrics=dict(
89-
loss=lantern.MapMetric(lambda examples, predictions, loss: loss),
90-
),
82+
name: dict(
83+
loss=lantern.MapMetric(lambda loss: loss.item()),
9184
)
92-
for name in evaluate_data_loaders.keys()
85+
for name in evaluate_data_loaders
9386
}
9487

95-
with lantern.module_eval(model), torch.no_grad():
96-
for name, data_loader in evaluate_data_loaders.items():
97-
for examples, targets in tqdm(data_loader, desc=name, leave=False):
88+
for name, data_loader in evaluate_data_loaders.items():
89+
for examples, targets in tqdm(data_loader, desc=name, leave=False):
90+
with lantern.module_eval(model):
9891
predictions = model(examples)
9992
loss = F.nll_loss(predictions, targets)
100-
evaluate_metrics[name].update_(examples, predictions, loss)
101-
evaluate_metrics[name].log_().print()
93+
94+
evaluate_metrics[name]["loss"].update_(loss)
95+
96+
for metric_name, metric in evaluate_metrics[name].items():
97+
metric.log(tensorboard_logger, name, metric_name, epoch)
98+
99+
print(lantern.MetricTable(name, evaluate_metrics[name]))
102100

103101
early_stopping = early_stopping.score(
104102
-evaluate_metrics["evaluate_early_stopping"]["loss"].compute()

0 commit comments

Comments (0)