Skip to content

Commit 6414565

Browse files
author
gongel
committed
fix: remove .numpy() to support static graph
1 parent ac54792 commit 6414565

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

paddlenlp/transformers/transformer/modeling.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ def grow_topk(i, logits, alive_seq, alive_log_probs, states):
931931

932932
# Length penalty is given by = (5+len(decode)/6) ^ -\alpha. Pls refer to
933933
# https://arxiv.org/abs/1609.08144.
934-
length_penalty = np.power(5.0 + (i.numpy()[0] + 1.0) / 6.0, alpha)
934+
length_penalty = paddle.pow(5.0 + (i + 1.0) / 6.0, alpha)
935935
curr_scores = log_probs / length_penalty
936936
flat_curr_scores = paddle.reshape(curr_scores, [batch_size, -1])
937937

@@ -1038,10 +1038,9 @@ def inner_loop(i, trg_word, alive_seq, alive_log_probs, finished_seq,
10381038

10391039
def is_not_finish(i, trg_word, alive_seq, alive_log_probs, finished_seq,
10401040
finished_scores, finished_flags, caches):
1041-
return paddle.to_tensor(
1042-
i.numpy()[0] < max_len and
1043-
not (early_finish(alive_log_probs, finished_scores,
1044-
finished_flags).numpy()[0]))
1041+
return paddle.greater_than(
1042+
i < max_len,
1043+
early_finish(alive_log_probs, finished_scores, finished_flags))
10451044

10461045
_, trg_word, alive_seq, alive_log_probs, finished_seq, finished_scores, finished_flags, caches = paddle.static.nn.while_loop(
10471046
is_not_finish,

0 commit comments

Comments
 (0)