@@ -218,21 +218,23 @@ def __init__(self, input, label, k=1, **kwargs):
218
218
raise ValueError ("You can only invoke Evaluator in root block" )
219
219
220
220
self .total_error = self .create_state (
221
- dtype = 'int64' , shape = [1 ], suffix = 'total' )
222
- self .batch_num = 0
221
+ dtype = 'float32' , shape = [1 ], suffix = 'total' )
222
+ self .batch_num = self .create_state (
223
+ dtype = 'float32' , shape = [1 ], suffix = 'total' )
223
224
error = layers .edit_distance (input = input , label = label )
224
- mean_error = layers .mean (input = error )
225
+ error = layers .cast (x = error , dtype = 'float32' )
226
+ mean_error = layers .mean (x = error )
225
227
layers .sums (input = [self .total_error , mean_error ], out = self .total_error )
228
+ const1 = layers .fill_constant (shape = [1 ], value = 1.0 , dtype = "float32" )
229
+ layers .sums (input = [self .batch_num , const1 ], out = self .batch_num )
226
230
self .metrics .append (mean_error )
227
231
228
232
def eval (self , executor , eval_program = None ):
229
- self .batch_num += 1
230
233
if eval_program is None :
231
234
eval_program = Program ()
232
235
block = eval_program .current_block ()
233
236
with program_guard (main_program = eval_program ):
234
237
total_error = _clone_var_ (block , self .total_error )
235
- batch_num = layers .fill_constant (
236
- shape = [1 ], value = self .batch_num , dtype = "float32" )
238
+ batch_num = _clone_var_ (block , self .batch_num )
237
239
out = layers .elementwise_div (x = total_error , y = batch_num )
238
240
return np .array (executor .run (eval_program , fetch_list = [out ])[0 ])
0 commit comments