Skip to content

Commit 2c85e19

Browse files
committed
refactor: simplify reduce and add some tests
1 parent fb6ba08 commit 2c85e19

File tree

1 file changed

+76
-30
lines changed

1 file changed

+76
-30
lines changed

lantern/metric.py

Lines changed: 76 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import functools
23
from lantern import FunctionalBase
34
from lantern.functional import star
45
from typing import Callable, Any, Optional, Dict, List, Union
@@ -12,38 +13,44 @@ class MapMetric(FunctionalBase):
1213
class Config:
1314
arbitrary_types_allowed = True
1415
allow_mutation = True
15-
extra = Extra.forbid
1616

17-
def __init__(self, map_fn=None, state=list()):
17+
def __init__(self, state=list(), map_fn=None):
1818
super().__init__(
19-
map_fn=map_fn,
2019
state=state,
20+
map_fn=map_fn,
2121
)
2222

2323
def map(self, fn):
2424
if self.map_fn is None:
25-
return self.replace(map_fn=fn)
25+
map_fn = fn
2626
else:
27-
return self.replace(
28-
map_fn=lambda *args, **kwargs: fn(self.map_fn(*args, **kwargs))
29-
)
27+
28+
def map_fn(*args, **kwargs):
29+
return fn(self.map_fn(*args, **kwargs))
30+
31+
return self.replace(
32+
map_fn=map_fn,
33+
state=list(map(fn, self.state)),
34+
)
3035

3136
def starmap(self, fn):
3237
return self.map(star(fn))
3338

34-
def reduce(self, fn):
39+
def reduce(self, fn, initial=None):
3540
if self.map_fn is None:
36-
return ReduceMetric(
37-
map_fn=lambda *args: args,
38-
reduce_fn=lambda state, args: fn(state, *args),
39-
state=self.state, # TODO: apply function on state...
40-
)
41+
42+
def reduce_fn(state, *args):
43+
return fn(state, *args)
44+
4145
else:
42-
return ReduceMetric(
43-
map_fn=self.map_fn,
44-
reduce_fn=fn,
45-
state=self.state,
46-
)
46+
47+
def reduce_fn(state, args):
48+
return fn(state, self.map_fn(args))
49+
50+
return ReduceMetric(
51+
reduce_fn=reduce_fn,
52+
state=functools.reduce(reduce_fn, self.state, initial),
53+
)
4754

4855
def aggregate(self, fn):
4956
return AggregateMetric(metric=self, aggregate_fn=fn)
@@ -86,32 +93,30 @@ def log_dict(self, tensorboard_logger, tag, step=None):
8693
)
8794
return self
8895

96+
def __call__(self):
97+
return self.compute()
98+
99+
def __iter__(self):
100+
return iter(self.compute())
101+
89102

90103
Metric = MapMetric
91104

92105

93106
class ReduceMetric(FunctionalBase):
94-
map_fn: Callable[..., Any]
95107
reduce_fn: Callable[..., Any]
96108
state: Any
97109

98110
class Config:
99111
arbitrary_types_allowed = True
100112
allow_mutation = True
101113

102-
def replace(self, **kwargs):
103-
new_dict = self.dict()
104-
new_dict.update(**kwargs)
105-
return type(self)(**new_dict)
106-
107114
def update_(self, *args, **kwargs):
108-
self.state = self.reduce_fn(self.state, self.map_fn(*args, **kwargs))
115+
self.state = self.reduce_fn(self.state, *args, **kwargs)
109116
return self
110117

111118
def update(self, *args, **kwargs):
112-
return self.replace(
113-
state=self.reduce_fn(self.state, self.map_fn(*args, **kwargs))
114-
)
119+
return self.replace(state=self.reduce_fn(self.state, *args, **kwargs))
115120

116121
def compute(self):
117122
return self.state
@@ -176,5 +181,46 @@ def log_dict(self, tensorboard_logger, tag, step=None):
176181
return self
177182

178183

179-
def test_metric():
180-
pass
184+
def test_map_update():
185+
assert Metric().map(lambda x: x * 2).update(2).compute() == [4]
186+
187+
188+
def test_map_after_update():
189+
assert Metric().update(2).map(lambda x: x * 2).compute() == [4]
190+
191+
192+
def test_reduce():
193+
assert Metric([2, 3]).reduce(lambda state, x: state + x, initial=0).compute() == 5
194+
195+
196+
def test_update_after_reduce():
197+
assert (
198+
Metric([2, 3]).reduce(lambda state, x: state + x, initial=0).update(2).compute()
199+
== 7
200+
)
201+
202+
203+
def test_aggregate():
204+
assert Metric([2, 3, 4]).aggregate(lambda xs: np.mean(xs)).compute() == 3
205+
206+
207+
def test_map_after_aggregate():
208+
assert (
209+
Metric([2, 3, 4])
210+
.aggregate(lambda xs: np.mean(xs))
211+
.map(lambda x: x ** 2)
212+
.compute()
213+
== 9
214+
)
215+
216+
217+
def test_update_last():
218+
assert (
219+
Metric()
220+
.aggregate(lambda xs: np.mean(xs))
221+
.map(lambda x: x ** 2)
222+
.update(2)
223+
.update(3)
224+
.compute()
225+
== 2.5 ** 2
226+
)

0 commit comments

Comments
 (0)