
Commit 7692c67

Update on "implement position encoding for shifted tokens"
In AttentionSink, the positional encoding uses tokens' positions in the KVCache instead of their positions in the actual text. When tokens get shifted in the KVCache, q and k's position embeddings need to be updated. The original [implementation](https://github.com/mit-han-lab/streaming-llm) of AttentionSink with RoPE caches the original q and k in the KVCache and applies the position embedding during inference.

This PR adds `RopeWithAttentionSink`. It assumes that q and k are already encoded with their original positions; when tokens are shifted, we reapply only the position delta (see the sketch below). This has two benefits:

- it minimizes our code, since the existing `llama_transformer` already applies the RoPE embedding before the KVCache update
- it avoids a performance regression when tokens are not shifted, because we don't need to reapply the position encoding in the KVCache for them

Differential Revision: [D65366440](https://our.internmc.facebook.com/intern/diff/D65366440/)

[ghstack-poisoned]
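As an illustration of the delta-rerotation idea, here is a minimal sketch, not the actual `RopeWithAttentionSink` implementation. It assumes an interleaved (even/odd pair) RoPE layout and uses hypothetical names (`rerotate_k`, `theta_base`):

```python
import torch


def rerotate_k(
    k: torch.Tensor,          # (seq_len, n_heads, head_dim), already RoPE-encoded
    original_position: int,   # position the cached k was encoded with
    new_position: int,        # position after the KVCache shift
    theta_base: float = 10000.0,
) -> torch.Tensor:
    """Rerotate cached keys by the position delta instead of re-encoding from scratch."""
    head_dim = k.shape[-1]
    # Per-pair RoPE frequencies: theta_i = theta_base^(-2i / head_dim).
    freqs = theta_base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    delta = (new_position - original_position) * freqs  # angle to rerotate each pair by
    cos, sin = torch.cos(delta), torch.sin(delta)

    k_even, k_odd = k[..., 0::2], k[..., 1::2]
    out = torch.empty_like(k)
    # 2-D rotation of each (even, odd) feature pair by delta:
    # (cos(delta), -sin(delta); sin(delta), cos(delta)).
    out[..., 0::2] = k_even * cos - k_odd * sin
    out[..., 1::2] = k_even * sin + k_odd * cos
    return out
```

Only the rows that actually move in the cache need this pass; unshifted rows keep their existing encoding, which is where the second benefit above comes from.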
1 parent 51d27f4 commit 7692c67

File tree: 1 file changed (+2, -2 lines)


examples/models/llama/source_transformation/attention_sink.py

Lines changed: 2 additions & 2 deletions
@@ -136,8 +136,8 @@ def forward(
         seq_len: int,
     ):
         """
-        Rerotate keys from original_position to new_position. This is done by rerotating
-        keys with (new_position * theta - original_position * theta) with the following matrix:
+        Rerotate q and k from original_position to new_position. This is done by rerotating q
+        and k with (new_position * theta - original_position * theta) with the following matrix:
         (cos(delta), -sin(delta)
         sin(delta), cos(delta))
         where delta = new_position * theta - original_position * theta
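To make the docstring's formula concrete, here is a small self-contained check (hypothetical scalar values, not repository code) that rotating an already-encoded pair by delta = new_position * theta - original_position * theta lands it in the same place as encoding the raw pair at new_position:

```python
import math


def rotate(x, y, angle):
    """Apply the 2x2 rotation matrix (cos, -sin; sin, cos) to the pair (x, y)."""
    c, s = math.cos(angle), math.sin(angle)
    return c * x - s * y, s * x + c * y


theta = 0.01                             # one RoPE frequency (hypothetical value)
original_position, new_position = 7, 3   # token shifted from cache slot 7 to slot 3
x, y = 0.5, -1.25                        # one (even, odd) feature pair of a key

# Key as it sits in the KVCache: already encoded at original_position.
kx, ky = rotate(x, y, original_position * theta)

# Rerotate by the delta only, as RopeWithAttentionSink does conceptually.
delta = new_position * theta - original_position * theta
rx, ry = rotate(kx, ky, delta)

# Reference: encode the raw pair directly at new_position.
ex, ey = rotate(x, y, new_position * theta)

assert math.isclose(rx, ex) and math.isclose(ry, ey)
```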
