Skip to content

Commit acbda44

Browse files
Merge pull request #8365 from wanghaoshuang/seq_error
Add sequence error output to edit distance evaluator
2 parents 261a12a + 8d57e9c commit acbda44

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

python/paddle/fluid/evaluator.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
__all__ = [
2323
'Accuracy',
2424
'ChunkEvaluator',
25+
'EditDistance',
2526
]
2627

2728

@@ -211,7 +212,7 @@ def eval(self, executor, eval_program=None):
211212
class EditDistance(Evaluator):
212213
"""
213214
Accumulate edit distance sum and sequence number from mini-batches and
214-
compute the average edit_distance of all batches.
215+
compute the average edit_distance and instance error of all batches.
215216
216217
Args:
217218
input: the sequences predicted by network.
@@ -227,14 +228,12 @@ class EditDistance(Evaluator):
227228
for epoch in PASS_NUM:
228229
distance_evaluator.reset(exe)
229230
for data in batches:
230-
loss, sum_distance = exe.run(fetch_list=[cost] + distance_evaluator.metrics)
231-
avg_distance = distance_evaluator.eval(exe)
232-
pass_distance = distance_evaluator.eval(exe)
231+
loss = exe.run(fetch_list=[cost])
232+
distance, instance_error = distance_evaluator.eval(exe)
233233
234234
In the above example:
235-
'sum_distance' is the sum of the batch's edit distance.
236-
'avg_distance' is the average of edit distance from the firt batch to the current batch.
237-
'pass_distance' is the average of edit distance from all the pass.
235+
'distance' is the average of the edit distance in a pass.
236+
'instance_error' is the instance error rate in a pass.
238237
239238
"""
240239

@@ -244,25 +243,45 @@ def __init__(self, input, label, ignored_tokens=None, **kwargs):
244243
if main_program.current_block().idx != 0:
245244
raise ValueError("You can only invoke Evaluator in root block")
246245

247-
self.total_error = self.create_state(
248-
dtype='float32', shape=[1], suffix='total_error')
246+
self.total_distance = self.create_state(
247+
dtype='float32', shape=[1], suffix='total_distance')
249248
self.seq_num = self.create_state(
250249
dtype='int64', shape=[1], suffix='seq_num')
251-
error, seq_num = layers.edit_distance(
250+
self.instance_error = self.create_state(
251+
dtype='int64', shape=[1], suffix='instance_error')
252+
distances, seq_num = layers.edit_distance(
252253
input=input, label=label, ignored_tokens=ignored_tokens)
253-
#error = layers.cast(x=error, dtype='float32')
254-
sum_error = layers.reduce_sum(error)
255-
layers.sums(input=[self.total_error, sum_error], out=self.total_error)
254+
255+
zero = layers.fill_constant(shape=[1], value=0.0, dtype='float32')
256+
compare_result = layers.equal(distances, zero)
257+
compare_result_int = layers.cast(x=compare_result, dtype='int')
258+
seq_right_count = layers.reduce_sum(compare_result_int)
259+
instance_error_count = layers.elementwise_sub(
260+
x=seq_num, y=seq_right_count)
261+
total_distance = layers.reduce_sum(distances)
262+
layers.sums(
263+
input=[self.total_distance, total_distance],
264+
out=self.total_distance)
256265
layers.sums(input=[self.seq_num, seq_num], out=self.seq_num)
257-
self.metrics.append(sum_error)
266+
layers.sums(
267+
input=[self.instance_error, instance_error_count],
268+
out=self.instance_error)
269+
self.metrics.append(total_distance)
270+
self.metrics.append(instance_error_count)
258271

259272
def eval(self, executor, eval_program=None):
260273
if eval_program is None:
261274
eval_program = Program()
262275
block = eval_program.current_block()
263276
with program_guard(main_program=eval_program):
264-
total_error = _clone_var_(block, self.total_error)
277+
total_distance = _clone_var_(block, self.total_distance)
265278
seq_num = _clone_var_(block, self.seq_num)
279+
instance_error = _clone_var_(block, self.instance_error)
266280
seq_num = layers.cast(x=seq_num, dtype='float32')
267-
out = layers.elementwise_div(x=total_error, y=seq_num)
268-
return np.array(executor.run(eval_program, fetch_list=[out])[0])
281+
instance_error = layers.cast(x=instance_error, dtype='float32')
282+
avg_distance = layers.elementwise_div(x=total_distance, y=seq_num)
283+
avg_instance_error = layers.elementwise_div(
284+
x=instance_error, y=seq_num)
285+
result = executor.run(
286+
eval_program, fetch_list=[avg_distance, avg_instance_error])
287+
return np.array(result[0]), np.array(result[1])

python/paddle/fluid/layers/nn.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2479,10 +2479,7 @@ def __check_input(x, y):
24792479
return out
24802480

24812481

2482-
def edit_distance(input,
2483-
label,
2484-
normalized=False,
2485-
ignored_tokens=None,
2482+
def edit_distance(input, label, normalized=True, ignored_tokens=None,
24862483
name=None):
24872484
"""
24882485
EditDistance operator computes the edit distances between a batch of

0 commit comments

Comments
 (0)