Skip to content

Commit eafa066

Browse files
authored
1 parent b456110 commit eafa066

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddlenlp/transformers/llama/modeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,8 @@ def _set_cos_sin_cache(self, seq_len):
417417

418418
def forward(self, x, seq_len=None):
419419
# x: [bs, num_attention_heads, seq_len, head_size]
420-
cos = self.cos_cached[:, :, :seq_len, ...]
421-
sin = self.sin_cached[:, :, :seq_len, ...]
420+
cos = self.cos_cached[:, :seq_len, :, :]
421+
sin = self.sin_cached[:, :seq_len, :, :]
422422
return (
423423
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
424424
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,

0 commit comments

Comments
 (0)