@@ -931,7 +931,7 @@ def grow_topk(i, logits, alive_seq, alive_log_probs, states):
931
931
932
932
# Length penalty is given by = (5+len(decode)/6) ^ -\alpha. Pls refer to
933
933
# https://arxiv.org/abs/1609.08144.
934
- length_penalty = np . power (5.0 + (i . numpy ()[ 0 ] + 1.0 ) / 6.0 , alpha )
934
+ length_penalty = paddle . pow (5.0 + (i + 1.0 ) / 6.0 , alpha )
935
935
curr_scores = log_probs / length_penalty
936
936
flat_curr_scores = paddle .reshape (curr_scores , [batch_size , - 1 ])
937
937
@@ -1038,10 +1038,9 @@ def inner_loop(i, trg_word, alive_seq, alive_log_probs, finished_seq,
1038
1038
1039
1039
def is_not_finish (i , trg_word , alive_seq , alive_log_probs , finished_seq ,
1040
1040
finished_scores , finished_flags , caches ):
1041
- return paddle .to_tensor (
1042
- i .numpy ()[0 ] < max_len and
1043
- not (early_finish (alive_log_probs , finished_scores ,
1044
- finished_flags ).numpy ()[0 ]))
1041
+ return paddle .greater_than (
1042
+ i < max_len ,
1043
+ early_finish (alive_log_probs , finished_scores , finished_flags ))
1045
1044
1046
1045
_ , trg_word , alive_seq , alive_log_probs , finished_seq , finished_scores , finished_flags , caches = paddle .static .nn .while_loop (
1047
1046
is_not_finish ,
0 commit comments