Commit 0fb51f9

Fix RoPE inputs in Attention.forward

1 parent: aed2b0e

1 file changed (+4, −4 lines)

scaled_dot_product_attention.py (4 additions, 4 deletions)
@@ -38,14 +38,14 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape)

         cos_table, sin_table = position_embeddings
-        sin_table = sin_table[0]
-        cos_table = cos_table[0]
+        sin_table = sin_table[0, ..., sin_table.shape[-1] // 2 :]
+        cos_table = cos_table[0, ..., cos_table.shape[-1] // 2 :]

         query_states = type(self).rotary_position_embedding(
-            query_states, sin_table, cos_table
+            query_states, sin_table, cos_table, interleaved=False
         )
         key_states = type(self).rotary_position_embedding(
-            key_states, sin_table, cos_table
+            key_states, sin_table, cos_table, interleaved=False
         )

         query_states = query_states.transpose(1, 2)
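Why slice off half of each table? In a common non-interleaved (GPT-NeoX-style) RoPE convention, the cos/sin tables are built by concatenating the same angle matrix with itself along the last axis, so each half of the table is a duplicate and either half carries the full rotation information. The repo's actual `rotary_position_embedding` is not shown in this diff, so the sketch below is a hypothetical standalone version under that convention; the names `build_rope_tables` and `apply_rope` are assumptions, not the commit's API.

```python
import numpy as np

def build_rope_tables(seq_len, head_dim, base=10000.0):
    # One frequency per rotated pair of dimensions: (head_dim // 2,)
    inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
    angles = np.outer(np.arange(seq_len), inv_freq)   # (seq_len, head_dim // 2)
    # Many implementations duplicate the angles so the table spans head_dim;
    # this duplication is what the commit's `[..., shape[-1] // 2 :]` slice undoes.
    emb = np.concatenate([angles, angles], axis=-1)   # (seq_len, head_dim)
    return np.cos(emb), np.sin(emb)

def apply_rope(x, sin_half, cos_half):
    # Non-interleaved ("half-split") rotation: pair dim i with dim i + head_dim/2
    # and rotate each pair by its position-dependent angle.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate(
        [x1 * cos_half - x2 * sin_half, x2 * cos_half + x1 * sin_half],
        axis=-1,
    )

cos_table, sin_table = build_rope_tables(seq_len=4, head_dim=8)
# Mirrors the commit: keep only one (de-duplicated) half of each table.
cos_half = cos_table[..., cos_table.shape[-1] // 2 :]
sin_half = sin_table[..., sin_table.shape[-1] // 2 :]

x = np.random.default_rng(0).standard_normal((4, 8))
y = apply_rope(x, sin_half, cos_half)
```

Under this convention the rotation leaves position 0 unchanged (all angles are zero there) and preserves vector norms at every position, which is a quick sanity check that the half-table slice loses no information.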
