Skip to content

Commit c321b93

Browse files
committed
improve: implement logging metrics to tensorboard
1 parent 57484d1 commit c321b93

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

lantern/metric.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,25 @@
22

33

44
class ReduceMetric:
5-
def __init__(self, reduce, compute, initial_state=None):
6-
self._reduce = reduce
7-
self._compute = compute
5+
def __init__(self, reduce_fn, compute_fn, initial_state=None):
6+
self.reduce_fn = reduce_fn
7+
self.compute_fn = compute_fn
88
self.state = initial_state
99

1010
def reduce(self, *args, **kwargs):
1111
return ReduceMetric(
12-
reduce=self._reduce,
13-
compute=self._compute,
14-
initial_state=self._reduce(self.state, *args, **kwargs),
12+
reduce_fn=self.reduce_fn,
13+
compute_fn=self.compute_fn,
14+
initial_state=self.reduce_fn(self.state, *args, **kwargs),
1515
)
1616

1717
def compute(self):
18-
return self._compute(self.state)
18+
return self.compute_fn(self.state)
1919

2020

21-
def MapMetric(map, compute=np.mean):
21+
def MapMetric(map_fn, compute_fn=np.mean):
2222
return ReduceMetric(
23-
reduce=lambda state, *args, **kwargs: state + [map(*args, **kwargs)],
24-
compute=compute,
23+
reduce_fn=lambda state, *args, **kwargs: state + [map_fn(*args, **kwargs)],
24+
compute_fn=compute_fn,
2525
initial_state=list(),
2626
)

lantern/metrics.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ def compute(self):
1818

1919

2020
class Metrics:
21-
def __init__(self, name, tensorboard_logger, metrics):
21+
def __init__(self, name, tensorboard_logger, metrics, n_logs=0):
2222
self.name = name
2323
self.tensorboard_logger = tensorboard_logger
2424
self.metrics = metrics
25+
self.n_logs = n_logs
2526

2627
def __getitem__(self, name_or_names):
2728
if type(name_or_names) == str:
@@ -39,8 +40,17 @@ def update_(self, *args, **kwargs):
3940
def compute(self):
4041
return {name: metric.compute() for name, metric in self.metrics.items()}
4142

42-
def log_(self):
43-
# TODO
43+
def log_(self, step=None):
44+
self.n_logs += 1
45+
if step is None:
46+
step = self.n_logs
47+
48+
for name, metric in self.metrics.items():
49+
self.tensorboard_logger.add_scalar(
50+
f"{self.name}/{name}",
51+
metric.compute(),
52+
step,
53+
)
4454
return self
4555

4656
def table(self):

0 commit comments

Comments
 (0)