5
5
from pydantic import BaseModel , Extra
6
6
7
7
8
- class MapMetric (BaseModel ):
9
- map_fn_ : Optional [Callable [..., Any ]]
10
- # map_fn: Optional[Callable] = lambda value: self, value # HACK: why are we getting self?
8
+ class MapMetric (FunctionalBase ):
9
+ map_fn : Optional [Callable [..., Any ]]
11
10
state : List [Any ]
12
11
13
12
class Config :
14
13
arbitrary_types_allowed = True
15
14
allow_mutation = True
16
15
extra = Extra .forbid
17
16
18
- def __init__ (self , map_fn_ = None , state = list ()):
17
+ def __init__ (self , map_fn = None , state = list ()):
19
18
super ().__init__ (
20
- map_fn_ = map_fn_ ,
19
+ map_fn = map_fn ,
21
20
state = state ,
22
21
)
23
22
24
- def replace (self , ** kwargs ):
25
- new_dict = self .dict ()
26
- new_dict .update (** kwargs )
27
- return type (self )(** new_dict )
28
-
29
23
def map (self , fn ):
30
- # return self.replace(fn=lambda value: fn(self.map_fn_(value)))
31
- # HACK: why doesn't the above work?
32
- if self .map_fn_ is None :
33
- return MapMetric (
34
- map_fn_ = fn ,
35
- state = self .state ,
36
- )
24
+ if self .map_fn is None :
25
+ return self .replace (map_fn = fn )
37
26
else :
38
- return MapMetric (
39
- map_fn_ = lambda * args , ** kwargs : fn (self .map_fn_ (* args , ** kwargs )),
40
- state = self .state ,
27
+ return self .replace (
28
+ map_fn = lambda * args , ** kwargs : fn (self .map_fn (* args , ** kwargs ))
41
29
)
42
30
43
31
def starmap (self , fn ):
44
32
return self .map (star (fn ))
45
33
46
34
def reduce (self , fn ):
47
- if self .map_fn_ is None :
35
+ if self .map_fn is None :
48
36
return ReduceMetric (
49
- map_fn_ = lambda * args : args ,
37
+ map_fn = lambda * args : args ,
50
38
reduce_fn = lambda state , args : fn (state , * args ),
51
39
state = self .state , # TODO: apply function on state...
52
40
)
53
41
else :
54
42
return ReduceMetric (
55
- map_fn_ = self .map_fn_ ,
43
+ map_fn = self .map_fn ,
56
44
reduce_fn = fn ,
57
45
state = self .state ,
58
46
)
@@ -64,22 +52,32 @@ def staraggregate(self, fn):
64
52
return self .aggregate (star (fn ))
65
53
66
54
def update_ (self , * args , ** kwargs ):
67
- if self .map_fn_ is None :
55
+ if self .map_fn is None :
68
56
self .state .append (args )
69
57
else :
70
- self .state .append (self .map_fn_ (* args , ** kwargs ))
58
+ self .state .append (self .map_fn (* args , ** kwargs ))
71
59
return self
72
60
73
61
def update (self , * args , ** kwargs ):
74
- if self .map_fn_ is None :
75
- return self .replace (state = self .state + ([args [0 ]] if len (args ) == 1 else [args ]))
62
+ if self .map_fn is None :
63
+ return self .replace (
64
+ state = self .state + ([args [0 ]] if len (args ) == 1 else [args ])
65
+ )
76
66
else :
77
- return self .replace (state = self .state + [self .map_fn_ (* args , ** kwargs )])
67
+ return self .replace (state = self .state + [self .map_fn (* args , ** kwargs )])
78
68
79
69
def compute (self ):
80
70
return self .state
81
71
82
- def log (self , tensorboard_logger , tag , step = None ):
72
+ def log (self , tensorboard_logger , tag , metric_name , step = None ):
73
+ tensorboard_logger .add_scalar (
74
+ f"{ tag } /{ metric_name } " ,
75
+ self .compute (),
76
+ step ,
77
+ )
78
+ return self
79
+
80
+ def log_dict (self , tensorboard_logger , tag , step = None ):
83
81
for name , value in self .compute ().items ():
84
82
tensorboard_logger .add_scalar (
85
83
f"{ tag } /{ name } " ,
@@ -92,34 +90,41 @@ def log(self, tensorboard_logger, tag, step=None):
92
90
Metric = MapMetric
93
91
94
92
95
- class ReduceMetric (BaseModel ):
96
- map_fn_ : Callable [..., Any ]
93
+ class ReduceMetric (FunctionalBase ):
94
+ map_fn : Callable [..., Any ]
97
95
reduce_fn : Callable [..., Any ]
98
96
state : Any
99
97
100
98
class Config :
101
99
arbitrary_types_allowed = True
102
100
allow_mutation = True
103
- extra = Extra .forbid
104
101
105
102
def replace (self , ** kwargs ):
106
103
new_dict = self .dict ()
107
104
new_dict .update (** kwargs )
108
105
return type (self )(** new_dict )
109
106
110
107
def update_ (self , * args , ** kwargs ):
111
- self .state = self .reduce_fn (self .state , self .map_fn_ (* args , ** kwargs ))
108
+ self .state = self .reduce_fn (self .state , self .map_fn (* args , ** kwargs ))
112
109
return self
113
110
114
111
def update (self , * args , ** kwargs ):
115
112
return self .replace (
116
- state = self .reduce_fn (self .state , self .map_fn_ (* args , ** kwargs ))
113
+ state = self .reduce_fn (self .state , self .map_fn (* args , ** kwargs ))
117
114
)
118
115
119
116
def compute (self ):
120
117
return self .state
121
118
122
- def log (self , tensorboard_logger , tag , step = None ):
119
+ def log (self , tensorboard_logger , tag , metric_name , step = None ):
120
+ tensorboard_logger .add_scalar (
121
+ f"{ tag } /{ metric_name } " ,
122
+ self .compute (),
123
+ step ,
124
+ )
125
+ return self
126
+
127
+ def log_dict (self , tensorboard_logger , tag , step = None ):
123
128
for name , value in self .compute ().items ():
124
129
tensorboard_logger .add_scalar (
125
130
f"{ tag } /{ name } " ,
@@ -129,24 +134,16 @@ def log(self, tensorboard_logger, tag, step=None):
129
134
return self
130
135
131
136
132
- class AggregateMetric (BaseModel ):
137
+ class AggregateMetric (FunctionalBase ):
133
138
metric : Union [MapMetric , ReduceMetric ]
134
139
aggregate_fn : Callable
135
140
136
141
class Config :
137
142
arbitrary_types_allowed = True
138
143
allow_mutation = True
139
- extra = Extra .forbid
140
-
141
- def replace (self , ** kwargs ):
142
- new_dict = self .dict ()
143
- new_dict .update (** kwargs )
144
- return type (self )(** new_dict )
145
144
146
145
def map (self , fn ):
147
- return self .replace (
148
- aggregate_fn = lambda state : fn (self .aggregate_fn (state ))
149
- )
146
+ return self .replace (aggregate_fn = lambda state : fn (self .aggregate_fn (state )))
150
147
151
148
def starmap (self , fn ):
152
149
return self .map (star (fn ))
@@ -161,7 +158,15 @@ def update(self, *args, **kwargs):
161
158
def compute (self ):
162
159
return self .aggregate_fn (self .metric .compute ())
163
160
164
- def log (self , tensorboard_logger , tag , step = None ):
161
+ def log (self , tensorboard_logger , tag , metric_name , step = None ):
162
+ tensorboard_logger .add_scalar (
163
+ f"{ tag } /{ metric_name } " ,
164
+ self .compute (),
165
+ step ,
166
+ )
167
+ return self
168
+
169
+ def log_dict (self , tensorboard_logger , tag , step = None ):
165
170
for name , value in self .compute ().items ():
166
171
tensorboard_logger .add_scalar (
167
172
f"{ tag } /{ name } " ,
0 commit comments