
Commit f2dbd81

GPT: fix batch predict. (#411)
1 parent c7914e9 commit f2dbd81

1 file changed: 2 additions (+2), 2 deletions (-2)


paddlenlp/transformers/gpt/modeling.py

Lines changed: 2 additions & 2 deletions
@@ -781,14 +781,14 @@ def model(self,
     def forward(self, input_ids, end_id):
         output, cached_kvs = self.model(input_ids, use_cache=True, cache=None)
         src_ids = input_ids
-        nid = paddle.argmax(output[0, -1]).reshape([1, -1])
+        nid = paddle.argmax(output[:, -1, :], axis=-1).reshape([-1, 1])
         src_ids = paddle.concat([src_ids, nid], axis=1)
         cur_len = 0
         while (cur_len < self.max_predict_len):
             output, cached_kvs = self.model(
                 nid, use_cache=True, cache=cached_kvs)
 
-            nid = paddle.argmax(output[0, -1]).reshape([1, -1])
+            nid = paddle.argmax(output[:, -1, :], axis=-1).reshape([-1, 1])
             src_ids = paddle.concat([src_ids, nid], axis=1)
             cur_len += 1
             if paddle.max(nid) == end_id:
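Why the one-line change fixes batched prediction: output has shape [batch_size, seq_len, vocab_size], so the old output[0, -1] reads only the first sequence's last-step logits, and its argmax reshaped to [1, -1] cannot be concatenated with a multi-row src_ids. The fixed version takes the argmax over the vocabulary axis for every sequence at once. A minimal shape sketch of the two variants (the toy tensor sizes below are illustrative assumptions, not part of the commit):

import paddle

# Toy logits for a batch of 3 sequences, 5 steps, vocab size 7 (hypothetical shapes).
output = paddle.rand([3, 5, 7])
src_ids = paddle.ones([3, 5], dtype="int64")

# Before the fix: indexes only the first sequence's last step; the flattened
# argmax reshaped to [1, 1] cannot be concatenated with src_ids when batch_size > 1.
old_nid = paddle.argmax(output[0, -1]).reshape([1, -1])
print(old_nid.shape)  # [1, 1]

# After the fix: one next-token id per batch element, shape [3, 1], which
# concatenates cleanly with src_ids of shape [3, seq_len] along axis=1.
new_nid = paddle.argmax(output[:, -1, :], axis=-1).reshape([-1, 1])
print(new_nid.shape)  # [3, 1]
src_ids = paddle.concat([src_ids, new_nid], axis=1)
print(src_ids.shape)  # [3, 6]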
