We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b456110 commit eafa066
paddlenlp/transformers/llama/modeling.py
@@ -417,8 +417,8 @@ def _set_cos_sin_cache(self, seq_len):

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        cos = self.cos_cached[:, :, :seq_len, ...]
-        sin = self.sin_cached[:, :, :seq_len, ...]
+        cos = self.cos_cached[:, :seq_len, :, :]
+        sin = self.sin_cached[:, :seq_len, :, :]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
0 commit comments