Skip to content

Commit a8f118c

Browse files
committed
Add EditDistance to evaluator.py
1 parent 680aec2 commit a8f118c

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

python/paddle/v2/fluid/evaluator.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,21 +218,23 @@ def __init__(self, input, label, k=1, **kwargs):
218218
raise ValueError("You can only invoke Evaluator in root block")
219219

220220
self.total_error = self.create_state(
221-
dtype='int64', shape=[1], suffix='total')
222-
self.batch_num = 0
221+
dtype='float32', shape=[1], suffix='total')
222+
self.batch_num = self.create_state(
223+
dtype='float32', shape=[1], suffix='total')
223224
error = layers.edit_distance(input=input, label=label)
224-
mean_error = layers.mean(input=error)
225+
error = layers.cast(x=error, dtype='float32')
226+
mean_error = layers.mean(x=error)
225227
layers.sums(input=[self.total_error, mean_error], out=self.total_error)
228+
const1 = layers.fill_constant(shape=[1], value=1.0, dtype="float32")
229+
layers.sums(input=[self.batch_num, const1], out=self.batch_num)
226230
self.metrics.append(mean_error)
227231

228232
def eval(self, executor, eval_program=None):
229-
self.batch_num += 1
230233
if eval_program is None:
231234
eval_program = Program()
232235
block = eval_program.current_block()
233236
with program_guard(main_program=eval_program):
234237
total_error = _clone_var_(block, self.total_error)
235-
batch_num = layers.fill_constant(
236-
shape=[1], value=self.batch_num, dtype="float32")
238+
batch_num = _clone_var_(block, self.batch_num)
237239
out = layers.elementwise_div(x=total_error, y=batch_num)
238240
return np.array(executor.run(eval_program, fetch_list=[out])[0])

0 commit comments

Comments
 (0)