Commit fc98084

gongel committed
fix: use alive_seq when no eos
1 parent 1795f68 commit fc98084

1 file changed: +15 -0 lines changed

paddlenlp/transformers/transformer/modeling.py

Lines changed: 15 additions & 0 deletions
@@ -1026,6 +1026,21 @@ def grow_finished(finished_seq, finished_scores, finished_flags,
                                 finished_flags).numpy():
                 break
 
+        # Accounting for corner case: It's possible that no sequence in alive for a
+        # particular batch item ever reached EOS. In that case, we should just copy
+        # the contents of alive for that batch item. tf.reduce_any(finished_flags, 1)
+        # if 0, means that no sequence for that batch index had reached EOS. We need
+        # to do the same for the scores as well.
+        finished_flags = paddle.any(paddle.cast(
+            finished_flags, dtype='bool'),
+                                    axis=1,
+                                    keepdim=True).tile([1, beam_size])
+        finished_seq = paddle.where(
+            finished_flags.unsqueeze(-1).tile([1, 1, alive_seq.shape[-1]]),
+            finished_seq, alive_seq)
+        finished_scores = paddle.where(finished_flags, finished_scores,
+                                       alive_log_probs)
+
         return finished_seq, finished_scores
 
 
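The added lines select, per batch item, between the finished beams and the alive beams. The following is a minimal plain-Python sketch of that selection, not the PaddleNLP code itself: the helper name fall_back_to_alive and the sample values are hypothetical, and nested lists stand in for tensors, assuming finished_flags has shape [batch, beam] and the sequences have shape [batch, beam, length]. The built-in any() over a row plays the role of paddle.any(..., axis=1) in the diff (the tf.reduce_any mentioned in the comment).

```python
def fall_back_to_alive(finished_seq, finished_scores, finished_flags,
                       alive_seq, alive_log_probs):
    """Per batch item, keep the finished beams if any beam reached EOS;
    otherwise copy that item's alive beams and their log-probabilities."""
    out_seq, out_scores = [], []
    for b, flags in enumerate(finished_flags):
        # True if at least one beam of this batch item has finished.
        any_finished = any(bool(f) for f in flags)
        src_seq = finished_seq[b] if any_finished else alive_seq[b]
        src_scores = finished_scores[b] if any_finished else alive_log_probs[b]
        out_seq.append([list(beam) for beam in src_seq])
        out_scores.append(list(src_scores))
    return out_seq, out_scores


# Hypothetical data: batch of 2, beam size 2, sequence length 3.
finished_seq = [[[1, 2, 3], [1, 4, 3]], [[0, 0, 0], [0, 0, 0]]]
finished_scores = [[-0.5, -0.9], [-1e9, -1e9]]
finished_flags = [[1, 0], [0, 0]]  # batch item 1 never reached EOS
alive_seq = [[[1, 5, 6], [1, 7, 8]], [[1, 9, 9], [1, 8, 8]]]
alive_log_probs = [[-1.2, -1.5], [-0.7, -1.1]]

seq, scores = fall_back_to_alive(finished_seq, finished_scores,
                                 finished_flags, alive_seq, alive_log_probs)
print(seq)
print(scores)
```

Because the flag is reduced over the beam axis before being tiled back, the choice is all-or-nothing for each batch item: a single finished beam is enough to keep every finished slot for that item, and the alive beams are used only when no beam finished.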
