Skip to content

Commit 795f572

Browse files
committed
Rename 'seq_error' to 'instance_error'
1 parent 87d90d2 commit 795f572

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

python/paddle/fluid/evaluator.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
__all__ = [
2323
'Accuracy',
2424
'ChunkEvaluator',
25+
'EditDistance',
2526
]
2627

2728

@@ -211,7 +212,7 @@ def eval(self, executor, eval_program=None):
211212
class EditDistance(Evaluator):
212213
"""
213214
Accumulate edit distance sum and sequence number from mini-batches and
214-
compute the average edit_distance of all batches.
215+
compute the average edit_distance and instance error of all batches.
215216
216217
Args:
217218
input: the sequences predicted by network.
@@ -228,11 +229,11 @@ class EditDistance(Evaluator):
228229
distance_evaluator.reset(exe)
229230
for data in batches:
230231
loss = exe.run(fetch_list=[cost])
231-
distance, sequence_error = distance_evaluator.eval(exe)
232+
distance, instance_error = distance_evaluator.eval(exe)
232233
233234
In the above example:
234235
'distance' is the average of the edit distance rate in a pass.
235-
'sequence_error' is the sequence error rate in a pass.
236+
'instance_error' is the instance error rate in a pass.
236237
237238
"""
238239

@@ -246,24 +247,27 @@ def __init__(self, input, label, ignored_tokens=None, **kwargs):
246247
dtype='float32', shape=[1], suffix='total_distance')
247248
self.seq_num = self.create_state(
248249
dtype='int64', shape=[1], suffix='seq_num')
249-
self.seq_error = self.create_state(
250-
dtype='int64', shape=[1], suffix='seq_error')
250+
self.instance_error = self.create_state(
251+
dtype='int64', shape=[1], suffix='instance_error')
251252
distances, seq_num = layers.edit_distance(
252253
input=input, label=label, ignored_tokens=ignored_tokens)
253254

254255
zero = layers.fill_constant(shape=[1], value=0.0, dtype='float32')
255256
compare_result = layers.equal(distances, zero)
256257
compare_result_int = layers.cast(x=compare_result, dtype='int')
257258
seq_right_count = layers.reduce_sum(compare_result_int)
258-
seq_error_count = layers.elementwise_sub(x=seq_num, y=seq_right_count)
259+
instance_error_count = layers.elementwise_sub(
260+
x=seq_num, y=seq_right_count)
259261
total_distance = layers.reduce_sum(distances)
260262
layers.sums(
261263
input=[self.total_distance, total_distance],
262264
out=self.total_distance)
263265
layers.sums(input=[self.seq_num, seq_num], out=self.seq_num)
264-
layers.sums(input=[self.seq_error, seq_error_count], out=self.seq_error)
266+
layers.sums(
267+
input=[self.instance_error, instance_error_count],
268+
out=self.instance_error)
265269
self.metrics.append(total_distance)
266-
self.metrics.append(seq_error_count)
270+
self.metrics.append(instance_error_count)
267271

268272
def eval(self, executor, eval_program=None):
269273
if eval_program is None:
@@ -272,11 +276,12 @@ def eval(self, executor, eval_program=None):
272276
with program_guard(main_program=eval_program):
273277
total_distance = _clone_var_(block, self.total_distance)
274278
seq_num = _clone_var_(block, self.seq_num)
275-
seq_error = _clone_var_(block, self.seq_error)
279+
instance_error = _clone_var_(block, self.instance_error)
276280
seq_num = layers.cast(x=seq_num, dtype='float32')
277-
seq_error = layers.cast(x=seq_error, dtype='float32')
281+
instance_error = layers.cast(x=instance_error, dtype='float32')
278282
avg_distance = layers.elementwise_div(x=total_distance, y=seq_num)
279-
avg_seq_error = layers.elementwise_div(x=seq_error, y=seq_num)
280-
result = executor.run(eval_program,
281-
fetch_list=[avg_distance, avg_seq_error])
283+
avg_instance_error = layers.elementwise_div(
284+
x=instance_error, y=seq_num)
285+
result = executor.run(
286+
eval_program, fetch_list=[avg_distance, avg_instance_error])
282287
return np.array(result[0]), np.array(result[1])

0 commit comments

Comments
 (0)