Skip to content

Commit 5e99706

Browse files
committed
Add sequence error result to edit distance evaluator
1 parent ef8cb8f commit 5e99706

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

python/paddle/v2/fluid/evaluator.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -243,25 +243,42 @@ def __init__(self, input, label, ignored_tokens=None, **kwargs):
243243
if main_program.current_block().idx != 0:
244244
raise ValueError("You can only invoke Evaluator in root block")
245245

246-
self.total_error = self.create_state(
247-
dtype='float32', shape=[1], suffix='total_error')
246+
self.total_distance = self.create_state(
247+
dtype='float32', shape=[1], suffix='total_distance')
248248
self.seq_num = self.create_state(
249249
dtype='int64', shape=[1], suffix='seq_num')
250-
error, seq_num = layers.edit_distance(
250+
self.seq_error = self.create_state(
251+
dtype='int64', shape=[1], suffix='seq_error')
252+
distances, seq_num = layers.edit_distance(
251253
input=input, label=label, ignored_tokens=ignored_tokens)
254+
255+
zero = layers.fill_constant(shape=[1], value=0.0, dtype='float32')
256+
compare_result = layers.equal(distances, zero)
257+
compare_result_int = layers.cast(x=compare_result, dtype='int')
258+
seq_right_count = layers.reduce_sum(compare_result_int)
259+
seq_error_count = layers.elementwise_sub(x=seq_num, y=seq_right_count)
252260
#error = layers.cast(x=error, dtype='float32')
253-
sum_error = layers.reduce_sum(error)
254-
layers.sums(input=[self.total_error, sum_error], out=self.total_error)
261+
total_distance = layers.reduce_sum(distances)
262+
layers.sums(
263+
input=[self.total_distance, total_distance],
264+
out=self.total_distance)
255265
layers.sums(input=[self.seq_num, seq_num], out=self.seq_num)
256-
self.metrics.append(sum_error)
266+
layers.sums(input=[self.seq_error, seq_error_count], out=self.seq_error)
267+
self.metrics.append(total_distance)
268+
self.metrics.append(seq_error_count)
257269

258270
def eval(self, executor, eval_program=None):
259271
if eval_program is None:
260272
eval_program = Program()
261273
block = eval_program.current_block()
262274
with program_guard(main_program=eval_program):
263-
total_error = _clone_var_(block, self.total_error)
275+
total_distance = _clone_var_(block, self.total_distance)
264276
seq_num = _clone_var_(block, self.seq_num)
277+
seq_error = _clone_var_(block, self.seq_error)
265278
seq_num = layers.cast(x=seq_num, dtype='float32')
266-
out = layers.elementwise_div(x=total_error, y=seq_num)
267-
return np.array(executor.run(eval_program, fetch_list=[out])[0])
279+
seq_error = layers.cast(x=seq_error, dtype='float32')
280+
avg_distance = layers.elementwise_div(x=total_distance, y=seq_num)
281+
avg_seq_error = layers.elementwise_div(x=seq_error, y=seq_num)
282+
result = executor.run(eval_program,
283+
fetch_list=[avg_distance, avg_seq_error])
284+
return np.array(result[0]), np.array(result[1])

0 commit comments

Comments
 (0)