Skip to content

Commit f7aa22b

Browse files
committed
feature!: chain metrics
BREAKING CHANGE
1 parent 86b10af commit f7aa22b

File tree

5 files changed

+175
-50
lines changed

5 files changed

+175
-50
lines changed

lantern/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from lantern.functional_base import FunctionalBase
77
from lantern.tensor import Tensor
88
from lantern.epochs import Epochs
9-
from lantern.metric import Metric, ReduceMetric, MapMetric
9+
from lantern.metric import Metric
1010
from lantern.metric_table import MetricTable
1111
from lantern.module_device import module_device
1212
from lantern.module_compose import ModuleCompose

lantern/metric.py

Lines changed: 151 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,175 @@
11
import numpy as np
2-
from abc import ABC, abstractmethod
32
from lantern import FunctionalBase
4-
from typing import Callable, Any, Optional
3+
from lantern.functional import star
4+
from typing import Callable, Any, Optional, Dict, List, Union
5+
from pydantic import BaseModel, Extra
56

67

7-
class Metric(ABC):
8-
@abstractmethod
9-
def update(self):
10-
...
8+
class MapMetric(BaseModel):
9+
map_fn_: Optional[Callable[..., Any]]
10+
# map_fn: Optional[Callable] = lambda value: self, value # HACK: why are we getting self?
11+
state: List[Any]
1112

12-
@abstractmethod
13-
def update_(self):
14-
...
13+
class Config:
14+
arbitrary_types_allowed = True
15+
allow_mutation = True
16+
extra = Extra.forbid
17+
18+
def __init__(self, map_fn_=None, state=list()):
19+
super().__init__(
20+
map_fn_=map_fn_,
21+
state=state,
22+
)
23+
24+
def replace(self, **kwargs):
25+
new_dict = self.dict()
26+
new_dict.update(**kwargs)
27+
return type(self)(**new_dict)
28+
29+
def map(self, fn):
30+
# return self.replace(fn=lambda value: fn(self.map_fn_(value)))
31+
# HACK: why doesn't the above work?
32+
if self.map_fn_ is None:
33+
return MapMetric(
34+
map_fn_=fn,
35+
state=self.state,
36+
)
37+
else:
38+
return MapMetric(
39+
map_fn_=lambda *args, **kwargs: fn(self.map_fn_(*args, **kwargs)),
40+
state=self.state,
41+
)
42+
43+
def starmap(self, fn):
44+
return self.map(star(fn))
45+
46+
def reduce(self, fn):
47+
if self.map_fn_ is None:
48+
return ReduceMetric(
49+
map_fn_=lambda *args: args,
50+
reduce_fn=lambda state, args: fn(state, *args),
51+
state=self.state, # TODO: apply function on state...
52+
)
53+
else:
54+
return ReduceMetric(
55+
map_fn_=self.map_fn_,
56+
reduce_fn=fn,
57+
state=self.state,
58+
)
59+
60+
def aggregate(self, fn):
61+
return AggregateMetric(metric=self, aggregate_fn=fn)
62+
63+
def staraggregate(self, fn):
64+
return self.aggregate(star(fn))
65+
66+
def update_(self, *args, **kwargs):
67+
if self.map_fn_ is None:
68+
self.state.append(args)
69+
else:
70+
self.state.append(self.map_fn_(*args, **kwargs))
71+
return self
72+
73+
def update(self, *args, **kwargs):
74+
if self.map_fn_ is None:
75+
return self.replace(state=self.state + ([args[0]] if len(args) == 1 else [args]))
76+
else:
77+
return self.replace(state=self.state + [self.map_fn_(*args, **kwargs)])
1578

16-
@abstractmethod
1779
def compute(self):
18-
...
80+
return self.state
1981

82+
def log(self, tensorboard_logger, tag, step=None):
83+
for name, value in self.compute().items():
84+
tensorboard_logger.add_scalar(
85+
f"{tag}/{name}",
86+
value,
87+
step,
88+
)
89+
return self
90+
91+
92+
Metric = MapMetric
2093

21-
class ReduceMetric(FunctionalBase, Metric):
22-
reduce_fn: Callable
23-
compute_fn: Callable
24-
state: Optional[Any]
94+
95+
class ReduceMetric(BaseModel):
96+
map_fn_: Callable[..., Any]
97+
reduce_fn: Callable[..., Any]
98+
state: Any
2599

26100
class Config:
101+
arbitrary_types_allowed = True
27102
allow_mutation = True
103+
extra = Extra.forbid
28104

29-
def __init__(self, reduce_fn, compute_fn=None, initial_state=None):
30-
super().__init__(
31-
reduce_fn=reduce_fn,
32-
compute_fn=((lambda x: x) if compute_fn is None else compute_fn),
33-
state=initial_state,
34-
)
105+
def replace(self, **kwargs):
106+
new_dict = self.dict()
107+
new_dict.update(**kwargs)
108+
return type(self)(**new_dict)
109+
110+
def update_(self, *args, **kwargs):
111+
self.state = self.reduce_fn(self.state, self.map_fn_(*args, **kwargs))
112+
return self
35113

36114
def update(self, *args, **kwargs):
37-
return self.replace(state=self.reduce_fn(self.state, *args, **kwargs))
115+
return self.replace(
116+
state=self.reduce_fn(self.state, self.map_fn_(*args, **kwargs))
117+
)
118+
119+
def compute(self):
120+
return self.state
121+
122+
def log(self, tensorboard_logger, tag, step=None):
123+
for name, value in self.compute().items():
124+
tensorboard_logger.add_scalar(
125+
f"{tag}/{name}",
126+
value,
127+
step,
128+
)
129+
return self
130+
131+
132+
class AggregateMetric(BaseModel):
133+
metric: Union[MapMetric, ReduceMetric]
134+
aggregate_fn: Callable
135+
136+
class Config:
137+
arbitrary_types_allowed = True
138+
allow_mutation = True
139+
extra = Extra.forbid
140+
141+
def replace(self, **kwargs):
142+
new_dict = self.dict()
143+
new_dict.update(**kwargs)
144+
return type(self)(**new_dict)
145+
146+
def map(self, fn):
147+
return self.replace(
148+
aggregate_fn=lambda state: fn(self.aggregate_fn(state))
149+
)
150+
151+
def starmap(self, fn):
152+
return self.map(star(fn))
38153

39154
def update_(self, *args, **kwargs):
40-
self.state = self.reduce_fn(self.state, *args, **kwargs)
155+
self.metric = self.metric.update(*args, **kwargs)
41156
return self
42157

158+
def update(self, *args, **kwargs):
159+
return self.replace(metric=self.metric.update(*args, **kwargs))
160+
43161
def compute(self):
44-
return self.compute_fn(self.state)
162+
return self.aggregate_fn(self.metric.compute())
45163

46-
def log(self, tensorboard_logger, tag, name, step=1):
47-
tensorboard_logger.add_scalar(
48-
f"{tag}/{name}",
49-
self.compute(),
50-
step,
51-
)
164+
def log(self, tensorboard_logger, tag, step=None):
165+
for name, value in self.compute().items():
166+
tensorboard_logger.add_scalar(
167+
f"{tag}/{name}",
168+
value,
169+
step,
170+
)
52171
return self
53172

54173

55-
def MapMetric(map_fn, compute_fn=np.mean):
56-
"""Metric version of `compute_fn(map(map_fn, input))`"""
57-
return ReduceMetric(
58-
reduce_fn=lambda state, *args, **kwargs: state + [map_fn(*args, **kwargs)],
59-
compute_fn=compute_fn,
60-
initial_state=list(),
61-
)
174+
def test_metric():
175+
pass

lantern/metric_table.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import textwrap
22
import pandas as pd
3-
from lantern import FunctionalBase, Metric
4-
from typing import Dict
3+
from lantern import FunctionalBase
4+
from typing import Dict, Union, Any
5+
6+
# from wire_damage.tools import MapMetric, ReduceMetric, AggregateMetric
57

68

79
class MetricTable(FunctionalBase):
810
name: str
9-
metrics: Dict[str, Metric]
11+
metrics: Dict[str, Any]
12+
# metrics: Dict[str, Union[MapMetric, ReduceMetric, AggregateMetric]]
1013

1114
class Config:
1215
arbitrary_types_allowed = True
@@ -18,7 +21,11 @@ def __init__(self, name, metrics):
1821
)
1922

2023
def compute(self):
21-
return {name: metric.compute() for name, metric in self.metrics.items()}
24+
return {
25+
metric_name: value
26+
for metrics in self.metrics.values()
27+
for metric_name, value in metrics.compute().items()
28+
}
2229

2330
def table(self):
2431
return "\n".join(

lantern/progress_bar.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,9 @@ def ProgressBar(data_loader, name, metrics: Optional[Dict[str, Metric]] = None):
1313
for item in tqdm_:
1414
yield item
1515
tqdm_.set_postfix(
16-
{name: metric.compute() for name, metric in metrics.items()}
16+
{
17+
name: value
18+
for metrics in metrics.values()
19+
for name, value in metrics.compute().items()
20+
}
1721
)

test/test_mnist.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_mnist():
5555
tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
5656
early_stopping = lantern.EarlyStopping(tensorboard_logger=tensorboard_logger)
5757
train_metrics = dict(
58-
loss=lantern.ReduceMetric(lambda state, loss: loss.item()),
58+
loss=lantern.Metric().reduce(lambda state, loss: dict(loss=loss.item())),
5959
)
6060

6161
for epoch in lantern.Epochs(2):
@@ -73,14 +73,14 @@ def test_mnist():
7373
train_metrics["loss"].update_(loss)
7474
sleep(0.5)
7575

76-
for name, metric in train_metrics.items():
77-
metric.log(tensorboard_logger, "train", name, epoch)
76+
for metrics in train_metrics.values():
77+
metrics.log(tensorboard_logger, "train", epoch)
7878

7979
print(lantern.MetricTable("train", train_metrics))
8080

8181
evaluate_metrics = {
8282
name: dict(
83-
loss=lantern.MapMetric(lambda loss: loss.item()),
83+
loss=lantern.Metric().reduce(lambda state, loss: dict(loss=loss.item())),
8484
)
8585
for name in evaluate_data_loaders
8686
}
@@ -93,13 +93,13 @@ def test_mnist():
9393

9494
evaluate_metrics[name]["loss"].update_(loss)
9595

96-
for metric_name, metric in evaluate_metrics[name].items():
97-
metric.log(tensorboard_logger, name, metric_name, epoch)
96+
for metrics in evaluate_metrics[name].values():
97+
metrics.log(tensorboard_logger, name, epoch)
9898

9999
print(lantern.MetricTable(name, evaluate_metrics[name]))
100100

101101
early_stopping = early_stopping.score(
102-
-evaluate_metrics["evaluate_early_stopping"]["loss"].compute()
102+
-evaluate_metrics["evaluate_early_stopping"]["loss"].compute()["loss"]
103103
)
104104
if early_stopping.scores_since_improvement == 0:
105105
torch.save(model.state_dict(), "model.pt")

0 commit comments

Comments
 (0)