Skip to content

Commit fb6ba08

Browse files
committed
improve: handle single value and dict metrics
1 parent f7aa22b commit fb6ba08

File tree

4 files changed

+75
-63
lines changed

4 files changed

+75
-63
lines changed

lantern/metric.py

Lines changed: 51 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,54 +5,42 @@
55
from pydantic import BaseModel, Extra
66

77

8-
class MapMetric(BaseModel):
9-
map_fn_: Optional[Callable[..., Any]]
10-
# map_fn: Optional[Callable] = lambda value: self, value # HACK: why are we getting self?
8+
class MapMetric(FunctionalBase):
9+
map_fn: Optional[Callable[..., Any]]
1110
state: List[Any]
1211

1312
class Config:
1413
arbitrary_types_allowed = True
1514
allow_mutation = True
1615
extra = Extra.forbid
1716

18-
def __init__(self, map_fn_=None, state=list()):
17+
def __init__(self, map_fn=None, state=list()):
1918
super().__init__(
20-
map_fn_=map_fn_,
19+
map_fn=map_fn,
2120
state=state,
2221
)
2322

24-
def replace(self, **kwargs):
25-
new_dict = self.dict()
26-
new_dict.update(**kwargs)
27-
return type(self)(**new_dict)
28-
2923
def map(self, fn):
30-
# return self.replace(fn=lambda value: fn(self.map_fn_(value)))
31-
# HACK: why doesn't the above work?
32-
if self.map_fn_ is None:
33-
return MapMetric(
34-
map_fn_=fn,
35-
state=self.state,
36-
)
24+
if self.map_fn is None:
25+
return self.replace(map_fn=fn)
3726
else:
38-
return MapMetric(
39-
map_fn_=lambda *args, **kwargs: fn(self.map_fn_(*args, **kwargs)),
40-
state=self.state,
27+
return self.replace(
28+
map_fn=lambda *args, **kwargs: fn(self.map_fn(*args, **kwargs))
4129
)
4230

4331
def starmap(self, fn):
4432
return self.map(star(fn))
4533

4634
def reduce(self, fn):
47-
if self.map_fn_ is None:
35+
if self.map_fn is None:
4836
return ReduceMetric(
49-
map_fn_=lambda *args: args,
37+
map_fn=lambda *args: args,
5038
reduce_fn=lambda state, args: fn(state, *args),
5139
state=self.state, # TODO: apply function on state...
5240
)
5341
else:
5442
return ReduceMetric(
55-
map_fn_=self.map_fn_,
43+
map_fn=self.map_fn,
5644
reduce_fn=fn,
5745
state=self.state,
5846
)
@@ -64,22 +52,32 @@ def staraggregate(self, fn):
6452
return self.aggregate(star(fn))
6553

6654
def update_(self, *args, **kwargs):
67-
if self.map_fn_ is None:
55+
if self.map_fn is None:
6856
self.state.append(args)
6957
else:
70-
self.state.append(self.map_fn_(*args, **kwargs))
58+
self.state.append(self.map_fn(*args, **kwargs))
7159
return self
7260

7361
def update(self, *args, **kwargs):
74-
if self.map_fn_ is None:
75-
return self.replace(state=self.state + ([args[0]] if len(args) == 1 else [args]))
62+
if self.map_fn is None:
63+
return self.replace(
64+
state=self.state + ([args[0]] if len(args) == 1 else [args])
65+
)
7666
else:
77-
return self.replace(state=self.state + [self.map_fn_(*args, **kwargs)])
67+
return self.replace(state=self.state + [self.map_fn(*args, **kwargs)])
7868

7969
def compute(self):
8070
return self.state
8171

82-
def log(self, tensorboard_logger, tag, step=None):
72+
def log(self, tensorboard_logger, tag, metric_name, step=None):
73+
tensorboard_logger.add_scalar(
74+
f"{tag}/{metric_name}",
75+
self.compute(),
76+
step,
77+
)
78+
return self
79+
80+
def log_dict(self, tensorboard_logger, tag, step=None):
8381
for name, value in self.compute().items():
8482
tensorboard_logger.add_scalar(
8583
f"{tag}/{name}",
@@ -92,34 +90,41 @@ def log(self, tensorboard_logger, tag, step=None):
9290
Metric = MapMetric
9391

9492

95-
class ReduceMetric(BaseModel):
96-
map_fn_: Callable[..., Any]
93+
class ReduceMetric(FunctionalBase):
94+
map_fn: Callable[..., Any]
9795
reduce_fn: Callable[..., Any]
9896
state: Any
9997

10098
class Config:
10199
arbitrary_types_allowed = True
102100
allow_mutation = True
103-
extra = Extra.forbid
104101

105102
def replace(self, **kwargs):
106103
new_dict = self.dict()
107104
new_dict.update(**kwargs)
108105
return type(self)(**new_dict)
109106

110107
def update_(self, *args, **kwargs):
111-
self.state = self.reduce_fn(self.state, self.map_fn_(*args, **kwargs))
108+
self.state = self.reduce_fn(self.state, self.map_fn(*args, **kwargs))
112109
return self
113110

114111
def update(self, *args, **kwargs):
115112
return self.replace(
116-
state=self.reduce_fn(self.state, self.map_fn_(*args, **kwargs))
113+
state=self.reduce_fn(self.state, self.map_fn(*args, **kwargs))
117114
)
118115

119116
def compute(self):
120117
return self.state
121118

122-
def log(self, tensorboard_logger, tag, step=None):
119+
def log(self, tensorboard_logger, tag, metric_name, step=None):
120+
tensorboard_logger.add_scalar(
121+
f"{tag}/{metric_name}",
122+
self.compute(),
123+
step,
124+
)
125+
return self
126+
127+
def log_dict(self, tensorboard_logger, tag, step=None):
123128
for name, value in self.compute().items():
124129
tensorboard_logger.add_scalar(
125130
f"{tag}/{name}",
@@ -129,24 +134,16 @@ def log(self, tensorboard_logger, tag, step=None):
129134
return self
130135

131136

132-
class AggregateMetric(BaseModel):
137+
class AggregateMetric(FunctionalBase):
133138
metric: Union[MapMetric, ReduceMetric]
134139
aggregate_fn: Callable
135140

136141
class Config:
137142
arbitrary_types_allowed = True
138143
allow_mutation = True
139-
extra = Extra.forbid
140-
141-
def replace(self, **kwargs):
142-
new_dict = self.dict()
143-
new_dict.update(**kwargs)
144-
return type(self)(**new_dict)
145144

146145
def map(self, fn):
147-
return self.replace(
148-
aggregate_fn=lambda state: fn(self.aggregate_fn(state))
149-
)
146+
return self.replace(aggregate_fn=lambda state: fn(self.aggregate_fn(state)))
150147

151148
def starmap(self, fn):
152149
return self.map(star(fn))
@@ -161,7 +158,15 @@ def update(self, *args, **kwargs):
161158
def compute(self):
162159
return self.aggregate_fn(self.metric.compute())
163160

164-
def log(self, tensorboard_logger, tag, step=None):
161+
def log(self, tensorboard_logger, tag, metric_name, step=None):
162+
tensorboard_logger.add_scalar(
163+
f"{tag}/{metric_name}",
164+
self.compute(),
165+
step,
166+
)
167+
return self
168+
169+
def log_dict(self, tensorboard_logger, tag, step=None):
165170
for name, value in self.compute().items():
166171
tensorboard_logger.add_scalar(
167172
f"{tag}/{name}",

lantern/metric_table.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@ def __init__(self, name, metrics):
2121
)
2222

2323
def compute(self):
24-
return {
25-
metric_name: value
26-
for metrics in self.metrics.values()
27-
for metric_name, value in metrics.compute().items()
28-
}
24+
log_dict = dict()
25+
for metric_name, metric in self.metrics.items():
26+
metric_value = metric.compute()
27+
if isinstance(metric_value, dict):
28+
log_dict.update(**metric_value)
29+
else:
30+
log_dict[metric_name] = metric_value
31+
return log_dict
2932

3033
def table(self):
3134
return "\n".join(

lantern/progress_bar.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ def ProgressBar(data_loader, name, metrics: Optional[Dict[str, Metric]] = None):
1212
with tqdm(data_loader, desc=name, leave=False) as tqdm_:
1313
for item in tqdm_:
1414
yield item
15-
tqdm_.set_postfix(
16-
{
17-
name: value
18-
for metrics in metrics.values()
19-
for name, value in metrics.compute().items()
20-
}
21-
)
15+
16+
log_dict = dict()
17+
for metric_name, metric in metrics.items():
18+
metric_value = metric.compute()
19+
if isinstance(metric_value, dict):
20+
log_dict.update(**metric_value)
21+
else:
22+
log_dict[metric_name] = metric_value
23+
tqdm_.set_postfix(log_dict)

test/test_mnist.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_mnist():
5555
tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
5656
early_stopping = lantern.EarlyStopping(tensorboard_logger=tensorboard_logger)
5757
train_metrics = dict(
58-
loss=lantern.Metric().reduce(lambda state, loss: dict(loss=loss.item())),
58+
loss=lantern.Metric().reduce(lambda state, loss: loss.item()),
5959
)
6060

6161
for epoch in lantern.Epochs(2):
@@ -73,14 +73,16 @@ def test_mnist():
7373
train_metrics["loss"].update_(loss)
7474
sleep(0.5)
7575

76-
for metrics in train_metrics.values():
77-
metrics.log(tensorboard_logger, "train", epoch)
76+
for metric_name, metric in train_metrics.items():
77+
metric.log(tensorboard_logger, "train", metric_name, epoch)
7878

7979
print(lantern.MetricTable("train", train_metrics))
8080

8181
evaluate_metrics = {
8282
name: dict(
83-
loss=lantern.Metric().reduce(lambda state, loss: dict(loss=loss.item())),
83+
loss=lantern.Metric().reduce(
84+
lambda state, loss: dict(loss=loss.item())
85+
),
8486
)
8587
for name in evaluate_data_loaders
8688
}
@@ -94,7 +96,7 @@ def test_mnist():
9496
evaluate_metrics[name]["loss"].update_(loss)
9597

9698
for metrics in evaluate_metrics[name].values():
97-
metrics.log(tensorboard_logger, name, epoch)
99+
metrics.log_dict(tensorboard_logger, name, epoch)
98100

99101
print(lantern.MetricTable(name, evaluate_metrics[name]))
100102

0 commit comments

Comments
 (0)