1
1
import numpy as np
2
+ import functools
2
3
from lantern import FunctionalBase
3
4
from lantern .functional import star
4
5
from typing import Callable , Any , Optional , Dict , List , Union
@@ -12,38 +13,44 @@ class MapMetric(FunctionalBase):
12
13
class Config :
13
14
arbitrary_types_allowed = True
14
15
allow_mutation = True
15
- extra = Extra .forbid
16
16
17
- def __init__ (self , map_fn = None , state = list ()):
17
+ def __init__ (self , state = list (), map_fn = None ):
18
18
super ().__init__ (
19
- map_fn = map_fn ,
20
19
state = state ,
20
+ map_fn = map_fn ,
21
21
)
22
22
23
23
def map (self , fn ):
24
24
if self .map_fn is None :
25
- return self . replace ( map_fn = fn )
25
+ map_fn = fn
26
26
else :
27
- return self .replace (
28
- map_fn = lambda * args , ** kwargs : fn (self .map_fn (* args , ** kwargs ))
29
- )
27
+
28
+ def map_fn (* args , ** kwargs ):
29
+ return fn (self .map_fn (* args , ** kwargs ))
30
+
31
+ return self .replace (
32
+ map_fn = map_fn ,
33
+ state = list (map (fn , self .state )),
34
+ )
30
35
31
36
def starmap (self , fn ):
32
37
return self .map (star (fn ))
33
38
34
- def reduce (self , fn ):
39
+ def reduce (self , fn , initial = None ):
35
40
if self .map_fn is None :
36
- return ReduceMetric (
37
- map_fn = lambda * args : args ,
38
- reduce_fn = lambda state , args : fn (state , * args ),
39
- state = self .state , # TODO: apply function on state...
40
- )
41
+
42
+ def reduce_fn (state , * args ):
43
+ return fn (state , * args )
44
+
41
45
else :
42
- return ReduceMetric (
43
- map_fn = self .map_fn ,
44
- reduce_fn = fn ,
45
- state = self .state ,
46
- )
46
+
47
+ def reduce_fn (state , args ):
48
+ return fn (state , self .map_fn (args ))
49
+
50
+ return ReduceMetric (
51
+ reduce_fn = reduce_fn ,
52
+ state = functools .reduce (reduce_fn , self .state , initial ),
53
+ )
47
54
48
55
def aggregate (self , fn ):
49
56
return AggregateMetric (metric = self , aggregate_fn = fn )
@@ -86,32 +93,30 @@ def log_dict(self, tensorboard_logger, tag, step=None):
86
93
)
87
94
return self
88
95
96
+ def __call__ (self ):
97
+ return self .compute ()
98
+
99
+ def __iter__ (self ):
100
+ return iter (self .compute ())
101
+
89
102
90
103
Metric = MapMetric
91
104
92
105
93
106
class ReduceMetric (FunctionalBase ):
94
- map_fn : Callable [..., Any ]
95
107
reduce_fn : Callable [..., Any ]
96
108
state : Any
97
109
98
110
class Config :
99
111
arbitrary_types_allowed = True
100
112
allow_mutation = True
101
113
102
- def replace (self , ** kwargs ):
103
- new_dict = self .dict ()
104
- new_dict .update (** kwargs )
105
- return type (self )(** new_dict )
106
-
107
114
def update_ (self , * args , ** kwargs ):
108
- self .state = self .reduce_fn (self .state , self . map_fn ( * args , ** kwargs ) )
115
+ self .state = self .reduce_fn (self .state , * args , ** kwargs )
109
116
return self
110
117
111
118
def update (self , * args , ** kwargs ):
112
- return self .replace (
113
- state = self .reduce_fn (self .state , self .map_fn (* args , ** kwargs ))
114
- )
119
+ return self .replace (state = self .reduce_fn (self .state , * args , ** kwargs ))
115
120
116
121
def compute (self ):
117
122
return self .state
@@ -176,5 +181,46 @@ def log_dict(self, tensorboard_logger, tag, step=None):
176
181
return self
177
182
178
183
179
- def test_metric ():
180
- pass
184
+ def test_map_update ():
185
+ assert Metric ().map (lambda x : x * 2 ).update (2 ).compute () == [4 ]
186
+
187
+
188
+ def test_map_after_update ():
189
+ assert Metric ().update (2 ).map (lambda x : x * 2 ).compute () == [4 ]
190
+
191
+
192
+ def test_reduce ():
193
+ assert Metric ([2 , 3 ]).reduce (lambda state , x : state + x , initial = 0 ).compute () == 5
194
+
195
+
196
+ def test_update_after_reduce ():
197
+ assert (
198
+ Metric ([2 , 3 ]).reduce (lambda state , x : state + x , initial = 0 ).update (2 ).compute ()
199
+ == 7
200
+ )
201
+
202
+
203
+ def test_aggregate ():
204
+ assert Metric ([2 , 3 , 4 ]).aggregate (lambda xs : np .mean (xs )).compute () == 3
205
+
206
+
207
+ def test_map_after_aggregate ():
208
+ assert (
209
+ Metric ([2 , 3 , 4 ])
210
+ .aggregate (lambda xs : np .mean (xs ))
211
+ .map (lambda x : x ** 2 )
212
+ .compute ()
213
+ == 9
214
+ )
215
+
216
+
217
+ def test_update_last ():
218
+ assert (
219
+ Metric ()
220
+ .aggregate (lambda xs : np .mean (xs ))
221
+ .map (lambda x : x ** 2 )
222
+ .update (2 )
223
+ .update (3 )
224
+ .compute ()
225
+ == 2.5 ** 2
226
+ )
0 commit comments