
Commit 51d27f4

implement position encoding for shifted tokens
AttentionSink uses tokens' positions in the KVCache instead of their positions in the actual text, so when tokens get shifted in the KVCache, the position embeddings of q and k need to be updated. The original [implementation](https://github.com/mit-han-lab/streaming-llm) of AttentionSink with Rope caches the original q and k in the KVCache and applies the position embedding during inference.

This PR adds `RopeWithAttentionSink` instead. It assumes that q and k are already encoded with their original positions; when tokens are shifted, we reapply only the position delta. This has two benefits:

- It minimizes our code, since the existing `llama_transformer` already applies the rope embedding before the KVCache update.
- It avoids a performance regression when tokens are not shifted, because we don't need to reapply position encoding in the KVCache for them.

Differential Revision: [D65366440](https://our.internmc.facebook.com/intern/diff/D65366440/)

[ghstack-poisoned]
1 parent 7140dec commit 51d27f4
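To illustrate the intended flow, the sketch below mirrors the new unit tests in this diff; it is not part of the change itself, and the 8-token shift is an arbitrary illustrative number. q and k are rotated once at their original position by the existing `Rope`, and when the cache later shifts those tokens, `RopeWithAttentionSink` rerotates them by the position delta only.

import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
from executorch.examples.models.llama.source_transformation.attention_sink import (
    RopeWithAttentionSink,
)

params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
rope = Rope(params)
rope_with_attention_sink = RopeWithAttentionSink(rope=rope)

seq_len = 4
head_dim = params.dim // params.n_heads
q = torch.randn(1, seq_len, params.n_heads, head_dim)
k = torch.randn(1, seq_len, params.n_heads, head_dim)

# Encode q and k at their original text position, as llama_transformer
# already does before the KVCache update.
original_position = 128
q_rot, k_rot = rope.forward(
    q=q,
    k=k,
    seq_len=seq_len,
    input_pos=torch.tensor([original_position], dtype=torch.int32),
)

# Later, suppose 8 tokens after the sink region are evicted, so these entries
# now live 8 slots earlier in the KVCache: reapply only the position delta.
new_position = original_position - 8
q_shifted, k_shifted = rope_with_attention_sink.forward(
    q=q_rot,
    k=k_rot,
    original_position=original_position,
    new_position=new_position,
    seq_len=seq_len,
)

# q_shifted / k_shifted now match what rope.forward would produce at
# new_position, which is what the unit tests in this diff verify.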

File tree

2 files changed: 154 additions & 0 deletions


examples/models/llama/source_transformation/attention_sink.py

Lines changed: 45 additions & 0 deletions
@@ -11,6 +11,8 @@

 import torch

+from executorch.examples.models.llama.llama_transformer import Rope
+
 from torch import nn


@@ -112,3 +114,46 @@ def update(
         narrowed_k.copy_(k_val)
         narrowed_v.copy_(v_val)
         return self.k_cache, self.v_cache
+
+
+class RopeWithAttentionSink(nn.Module):
+    """
+    Rope that helps adjust position encoding when tokens are shifted in KVCache.
+    For AttentionSink, when tokens are shifted in KVCache, we need to use positions
+    in KVCache instead of positions in the actual text.
+    """
+
+    def __init__(self, rope: Rope):
+        super().__init__()
+        self.rope = rope
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        original_position: int,
+        new_position: int,
+        seq_len: int,
+    ):
+        """
+        Rerotate keys from original_position to new_position. This is done by rerotating
+        keys with (new_position * theta - original_position * theta) with the following matrix:
+        (cos(delta), -sin(delta)
+         sin(delta), cos(delta))
+        where delta = new_position * theta - original_position * theta
+
+        Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961
+        """
+        original_freqs_cos = self.rope.freqs_cos.narrow(0, original_position, seq_len)
+        original_freqs_sin = self.rope.freqs_sin.narrow(0, original_position, seq_len)
+        new_freqs_cos = self.rope.freqs_cos.narrow(0, new_position, seq_len)
+        new_freqs_sin = self.rope.freqs_sin.narrow(0, new_position, seq_len)
+        rerotation_cos = (
+            new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin
+        )
+        rerotation_sin = (
+            new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin
+        )
+
+        q, k = self.rope.apply_rotary_emb(q, k, rerotation_cos, rerotation_sin)
+        return q, k
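For reference, `rerotation_cos` and `rerotation_sin` above are the angle-difference identities applied to the cached frequency tables. With theta_i the RoPE frequency of dimension pair i, original position o, and new position n:

\begin{aligned}
\delta_i &= (n - o)\,\theta_i, \\
\cos\delta_i &= \cos(n\theta_i)\cos(o\theta_i) + \sin(n\theta_i)\sin(o\theta_i), \\
\sin\delta_i &= \sin(n\theta_i)\cos(o\theta_i) - \cos(n\theta_i)\sin(o\theta_i),
\end{aligned}

so applying `apply_rotary_emb` with these coefficients rotates each dimension pair by delta_i, moving an entry encoded at position o to position n without undoing the original rotation or needing a new frequency table.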

examples/models/llama/source_transformation/test_attention_sink.py

Lines changed: 109 additions & 0 deletions
@@ -7,9 +7,11 @@
 import unittest

 import torch
+from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope

 from executorch.examples.models.llama.source_transformation.attention_sink import (
     KVCacheWithAttentionSink,
+    RopeWithAttentionSink,
 )


@@ -176,3 +178,110 @@ def test_update_with_all_shift(self):

         torch.testing.assert_close(k_out, expected_k_out)
         torch.testing.assert_close(v_out, expected_v_out)
+
+
+class RopeWithAttentionSinkTest(unittest.TestCase):
+
+    def setUp(self):
+        self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
+        self.rope = Rope(self.params)
+        self.rope_with_attention_sink = RopeWithAttentionSink(rope=self.rope)
+        self.seq_len = 32
+        self.n_local_heads = self.params.n_heads
+        self.n_local_kv_heads = self.params.n_heads
+        self.head_dim = self.params.dim // self.params.n_heads
+        self.q = torch.ones(
+            (1, self.seq_len, self.n_local_heads, self.head_dim), dtype=torch.float32
+        )
+        self.k = torch.full(
+            (1, self.seq_len, self.n_local_kv_heads, self.head_dim),
+            2,
+            dtype=torch.float32,
+        )
+
+    def test_rotate_backward(self):
+        original_position = 128
+        new_position = 127
+
+        pre_rotated_q, pre_rotated_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([original_position], dtype=torch.int32),
+        )
+
+        q, k = self.rope_with_attention_sink.forward(
+            q=pre_rotated_q,
+            k=pre_rotated_k,
+            original_position=original_position,
+            new_position=new_position,
+            seq_len=self.seq_len,
+        )
+
+        expected_q, expected_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([new_position], dtype=torch.int32),
+        )
+
+        torch.testing.assert_close(q, expected_q)
+        torch.testing.assert_close(k, expected_k)
+
+    def test_rotate_inplace(self):
+        original_position = 128
+        new_position = 128
+
+        pre_rotated_q, pre_rotated_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([original_position], dtype=torch.int32),
+        )
+
+        q, k = self.rope_with_attention_sink.forward(
+            q=pre_rotated_q,
+            k=pre_rotated_k,
+            original_position=original_position,
+            new_position=new_position,
+            seq_len=self.seq_len,
+        )
+
+        expected_q, expected_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([new_position], dtype=torch.int32),
+        )
+
+        torch.testing.assert_close(q, expected_q)
+        torch.testing.assert_close(k, expected_k)
+
+    def test_rotate_forward(self):
+        original_position = 128
+        new_position = 129
+
+        pre_rotated_q, pre_rotated_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([original_position], dtype=torch.int32),
+        )
+
+        q, k = self.rope_with_attention_sink.forward(
+            q=pre_rotated_q,
+            k=pre_rotated_k,
+            original_position=original_position,
+            new_position=new_position,
+            seq_len=self.seq_len,
+        )
+
+        expected_q, expected_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([new_position], dtype=torch.int32),
+        )
+
+        torch.testing.assert_close(q, expected_q)
+        torch.testing.assert_close(k, expected_k)
