Skip to content

Commit fa254cf

Browse files
committed
Llama: slice rotary cos/sin with leading-dim indexing (4-D slice `[:, :, ...]` → 1-D slice `[...]`)
1 parent 2bdbf2d commit fa254cf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/petals/models/llama/block.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def forward(
 85         if past_key_value is not None:
 86             kv_seq_len += past_key_value[0].shape[-2]
 87         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-88         cos = cos[:, :, kv_seq_len - q_len :]
-89         sin = sin[:, :, kv_seq_len - q_len :]
+88         cos = cos[kv_seq_len - q_len :]
+89         sin = sin[kv_seq_len - q_len :]
 90
 91         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
 92             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)

0 commit comments

Comments
 (0)