1
- from paddle .v2 .framework .framework import Program , g_main_program , unique_name
2
- from paddle .v2 .framework .layer_helper import LayerHelper
1
+ from paddle .v2 .framework .framework import Program , g_main_program , unique_name , Variable
3
2
import paddle .v2 .framework .core as core
4
3
5
4
5
+ def _clone_var_in_block_ (block , var ):
6
+ assert isinstance (var , Variable )
7
+ return block .create_var (
8
+ name = var .name ,
9
+ shape = var .shape ,
10
+ dtype = var .data_type ,
11
+ type = var .type ,
12
+ lod_level = var .lod_level ,
13
+ persistable = True )
14
+
15
+
6
16
class Evaluator (object ):
7
17
"""
8
18
Evalutor Base class.
@@ -13,33 +23,49 @@ class Evaluator(object):
13
23
"""
14
24
15
25
def __init__ (self , name , ** kwargs ):
26
+ """
27
+ init the global states
28
+ """
16
29
self ._states = {}
17
- if kwargs .has_key ("program" ):
18
- self ._program = kwargs .get ("program" )
30
+ if kwargs .has_key ("main_program" ):
31
+ self ._main_program = kwargs .get ("main_program" )
32
+ else :
33
+ self ._main_program = g_main_program
34
+ if kwargs .has_key ("eval_program" ):
35
+ self ._eval_program = kwargs .get ("eval_program" )
19
36
else :
20
- self ._program = g_main_program
37
+ self ._eval_program = Program ()
38
+
39
+ def _update_ops (self ):
40
+ """
41
+ append update ops to the global states
42
+ """
43
+ raise NotImplementedError ()
21
44
22
45
def reset (self , executor , program = None ):
23
46
"""
24
- Clear metric states at the begin of each pass/user specified batch
25
- """
47
+ Clear metric states at the begin of each pass/user specified batch
48
+ """
26
49
if program == None :
27
50
reset_program = Program ()
28
51
else :
29
52
reset_program = program
30
53
block = reset_program .global_block ()
31
54
for k , var in self ._states .iteritems ():
32
- zeros = block .create_var (dtype = var .data_type )
55
+ g_var = _clone_var_in_block_ (block , var )
56
+ zeros = block .create_var (dtype = "float32" , persistable = True )
33
57
block .append_op (
34
58
type = "fill_constant" ,
35
59
outputs = {"Out" : [zeros ]},
36
60
attrs = {
37
- "shape" : var .shape ,
38
- "value" : 0 ,
61
+ "shape" : g_var .shape ,
62
+ "value" : .0 ,
63
+ "data_type" : 5 ,
39
64
})
40
65
block .append_op (
41
- type = "scale" , inputs = {"X" : zeros }, outputs = {"Out" : var })
42
- executor .run (reset_program )
66
+ type = "scale" , inputs = {"X" : zeros }, outputs = {"Out" : g_var })
67
+ print reset_program
68
+ executor .run (reset_program , fetch_list = self ._states .values ())
43
69
44
70
def eval (self , executor , program = None ):
45
71
"""
@@ -53,22 +79,25 @@ class Accuracy(Evaluator):
53
79
Accuracy need two state variable Total, Correct
54
80
"""
55
81
56
- def __init__ (self , input , label , k = 1 , ** kwargs ):
82
+ def __init__ (self , * args , ** kwargs ):
57
83
super (Accuracy , self ).__init__ ("accuracy" , ** kwargs )
58
- block = self ._program .global_block ()
84
+ # block = self._eval_program.global_block()
85
+ block = self ._main_program .global_block ()
59
86
g_total = block .create_var (
60
87
name = unique_name ("Total" ),
61
88
persistable = True ,
62
89
dtype = "int64" ,
63
90
shape = [1 ])
64
- g_correct = helper . create_global_variable (
91
+ g_correct = block . create_var (
65
92
name = unique_name ("Correct" ),
66
93
persistable = True ,
67
94
dtype = "int64" ,
68
95
shape = [1 ])
69
96
self ._states ["Total" ] = g_total
70
97
self ._states ["Correct" ] = g_correct
71
98
99
+ def _update_ops (self , input , label , k = 1 , ** kwargs ):
100
+ block = self ._main_program .global_block ()
72
101
topk_out = block .create_var (dtype = input .data_type )
73
102
topk_indices = block .create_var (dtype = "int64" )
74
103
block .append_op (
@@ -77,8 +106,9 @@ def __init__(self, input, label, k=1, **kwargs):
77
106
outputs = {"Out" : [topk_out ],
78
107
"Indices" : [topk_indices ]},
79
108
attrs = {"k" : k })
80
- acc_out_dtype = kwargs .get ("out_dtype" , "float32" )
81
- acc_out = block .create_var (dtype = acc_out_dtype )
109
+ acc_out = block .create_var (dtype = kwargs .get ("out_dtype" , "float32" ))
110
+ correct = block .create_var (dtype = "int64" , persistable = True )
111
+ total = block .create_var (dtype = "int64" , persistable = True )
82
112
block .append_op (
83
113
type = "accuracy" ,
84
114
inputs = {
@@ -92,39 +122,121 @@ def __init__(self, input, label, k=1, **kwargs):
92
122
"Total" : [total ],
93
123
})
94
124
125
+ # block = self._eval_program.global_block()
126
+ # e_correct = _clone_var_in_block_(block, correct)
127
+ # e_total = _clone_var_in_block_(block, total)
128
+
129
+ # block.append_op(
130
+ # type="sum",
131
+ # inputs={"X": [self._states["Total"], total]},
132
+ # outputs={"Out": [self._states["Total"]]})
133
+ block .append_op (
134
+ type = "cast" ,
135
+ inputs = {"X" : [self ._states ["Total" ]]},
136
+ outputs = {"Out" : [self ._states ["Total" ]]},
137
+ attrs = {
138
+ "in_data_type" : 5 ,
139
+ "out_data_type" : 2 ,
140
+ })
141
+ block .append_op (
142
+ type = "cast" ,
143
+ inputs = {"X" : [self ._states ["Correct" ]]},
144
+ outputs = {"Out" : [self ._states ["Correct" ]]},
145
+ attrs = {
146
+ "in_data_type" : 5 ,
147
+ "out_data_type" : 2 ,
148
+ })
149
+
95
150
block .append_op (
96
- type = "sum" ,
97
- inputs = {"X" : [g_total , total ]},
98
- outputs = {"Out" : [g_total ]})
151
+ type = "elementwise_add" ,
152
+ inputs = {"X" : [self ._states ["Total" ]],
153
+ "Y" : [total ]},
154
+ outputs = {"Out" : [self ._states ["Total" ]]})
99
155
block .append_op (
100
- type = "sum" ,
101
- inputs = {"X" : [g_correct , correct ]},
102
- outputs = {"Out" : [g_total ]})
156
+ type = "elementwise_add" ,
157
+ inputs = {"X" : [self ._states ["Correct" ]],
158
+ "Y" : [correct ]},
159
+ outputs = {"Out" : [self ._states ["Correct" ]]})
160
+
161
+ # g_total = self._states["Total"]
162
+ # print g_total
163
+ # print total
164
+
165
+ # print "*" * 100
166
+ # print g_total.block.program == total.block.program
167
+
168
+ # g_total = _clone_var_in_block_(block, self._states["Total"])
169
+ # e_total = _clone_var_in_block_(block, total)
170
+
171
+ # block.append_op(
172
+ # type="sum",
173
+ # inputs={"X": [g_total, e_total]},
174
+ # outputs={"Out": [g_total]})
175
+
176
+ # block.append_op(
177
+ # type="sum",
178
+ # inputs={"X": [self._states["Correct"], correct]},
179
+ # outputs={"Out": [self._states["Correct"]]})
180
+ # print self._main_program
103
181
return acc_out
104
182
105
- def eval (self , executor , program = None ):
106
- if program == None :
107
- eval_program = Program ()
108
- else :
109
- eval_program = program
110
- block = eval_program .global_block ()
111
- eval_out = block .create_var (dtype = self ._helper .input_dtype ())
183
+ def eval (self , executor ):
184
+ block = self ._eval_program .global_block ()
185
+ eval_out = block .create_var (dtype = self ._states ["Total" ].data_type )
186
+ e_correct = _clone_var_in_block_ (block , correct )
187
+ e_total = _clone_var_in_block_ (block , total )
188
+ # block.append_op(
189
+ # type="elementwise_div",
190
+ # inputs={"X": self._states["Total"],
191
+ # "Y": self._states["Correct"]},
192
+ # outputs={"Out": eval_out})
112
193
block .append_op (
113
194
type = "elementwise_div" ,
114
- inputs = {"X" : self . _states [ "Total" ] ,
115
- "Y" : self . _states [ "Correct" ] },
195
+ inputs = {"X" : e_total ,
196
+ "Y" : e_correct },
116
197
outputs = {"Out" : eval_out })
117
- return executor .run (eval_program , fetch_list = [eval_out ])
198
+ return executor .run (self . _eval_program , fetch_list = [eval_out ])
118
199
119
200
120
- # Demo for composing low level op to compute the F1 metric
121
- class F1 (Evaluator ):
122
- def __init__ (self , input , label , ** kwargs ):
123
- super (F1 , self ).__init__ ("F1" , ** kwargs )
124
- g_tp = helper .create_global_variable (
201
+ # Demo for composing low level ops to compute the F1 metric
202
+ class FScore (Evaluator ):
203
+ def __init__ (self , input , label , beta = 1.0 , ** kwargs ):
204
+ super (F1 , self ).__init__ ("FScore" , ** kwargs )
205
+ block = self ._program .global_block ()
206
+ g_tp = block .create_var (
125
207
name = unique_name ("Tp" ), persistable = True , dtype = "int64" , shape = [1 ])
126
- g_fp = helper .create_global_variable (
208
+ g_fn = block .create_var (
209
+ name = unique_name ("Fn" ), persistable = True , dtype = "int64" , shape = [1 ])
210
+ g_fp = block .create_var (
127
211
name = unique_name ("Fp" ), persistable = True , dtype = "int64" , shape = [1 ])
128
212
129
213
self ._states ["Tp" ] = g_tp
130
214
self ._states ["Fp" ] = g_fp
215
+ self ._states ["Fn" ] = g_fn
216
+
217
+ def _update_ops (self ):
218
+ block = self ._program .global_block ()
219
+ equal_out = block .create_var ()
220
+ block .append_op (
221
+ type = "equal" ,
222
+ inputs = {"X" : [input ],
223
+ "Y" : [label ]},
224
+ outputs = {"Out" : equal_out })
225
+
226
+ positive = block .create_var ()
227
+ block .append_op (
228
+ type = "sequence_pool" ,
229
+ inputs = {"X" : [equal_out ]},
230
+ outputs = {"Out" : positive },
231
+ attrs = {"pooltype" : "SUM" })
232
+ batch = block .create_var (
233
+ name = feed_var_name ,
234
+ type = core .VarDesc .VarType .FEED_MINIBATCH ,
235
+ persistable = True )
236
+
237
+
238
+ # def register():
239
+ accuracy = Accuracy
240
+ # def accuracy(*args, **kwargs):
241
+ # acc = Accuracy(**kwargs)
242
+ # return acc._update_ops(*args, **kwargs)
0 commit comments