
Commit fad921e

Update on "implement position encoding for shifted tokens"
AttentionSink uses tokens' positions in the KVCache instead of their positions in the actual text. When tokens get shifted in the KVCache, the position embeddings of q and k need to be updated. The original [implementation](https://github.com/mit-han-lab/streaming-llm) of AttentionSink with RoPE caches the original q and k in the KVCache and applies the position embedding during inference.

This PR adds `RopeWithAttentionSink`. It assumes that q and k are already encoded with their original positions. When we shift tokens, we reapply only the position delta. This has two benefits:

- It minimizes our code, since our existing `llama_transformer` applies the rope embedding before doing the KVCache update.
- It avoids a performance regression when tokens are not shifted, because we don't need to reapply position encoding in the KVCache for them.

Differential Revision: [D65366440](https://our.internmc.facebook.com/intern/diff/D65366440/)

[ghstack-poisoned]
1 parent 7692c67 commit fad921e
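For illustration, a minimal sketch of how the new module is intended to be used (not part of this diff; the import paths, positions, and tensor values below are assumptions based on the test file further down):

```python
# Hedged sketch only: q and k are assumed to already carry the rope encoding
# for their original text position (llama_transformer applies rope before the
# KVCache update); when the cache shifts them, only the position delta is
# reapplied. Import paths are assumed and may differ in the ExecuTorch tree.
import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
from executorch.examples.models.llama.source_transformation.attention_sink import (
    RopeWithAttentionSink,
)

params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
rope = Rope(params)
rope_with_attention_sink = RopeWithAttentionSink(rope=rope)

seq_len = 32
head_dim = params.dim // params.n_heads
q = torch.randn(1, seq_len, params.n_heads, head_dim)
k = torch.randn(1, seq_len, params.n_heads, head_dim)

# Encode q and k at their original position (128) in the text.
q, k = rope.forward(
    q=q,
    k=k,
    seq_len=seq_len,
    input_pos=torch.tensor([128], dtype=torch.int32),
)

# Later, when the attention-sink KV cache shifts these tokens to position 127,
# rerotate them by the position delta instead of re-encoding from scratch.
q, k = rope_with_attention_sink.forward(
    q=q,
    k=k,
    original_position=128,
    new_position=127,
    seq_len=seq_len,
)
```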

File tree

2 files changed (+150 −150 lines)


examples/models/llama/source_transformation/attention_sink.py

Lines changed: 43 additions & 43 deletions
@@ -16,6 +16,49 @@
 from torch import nn
 
 
+class RopeWithAttentionSink(nn.Module):
+    """
+    Rope that helps adjust position encoding when tokens are shifted in KVCache.
+    For AttentionSink, when tokens are shifted in KVCache, we need to use positions
+    in KVCache instead of positions in the actual text.
+    """
+
+    def __init__(self, rope: Rope):
+        super().__init__()
+        self.rope = rope
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        original_position: int,
+        new_position: int,
+        seq_len: int,
+    ):
+        """
+        Rerotate q and k from original_position to new_position. This is done by rerotating q
+        and k with (new_position * theta - original_position * theta) with the following matrix:
+        (cos(delta), -sin(delta)
+         sin(delta), cos(delta))
+        where delta = new_position * theta - original_position * theta
+
+        Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961
+        """
+        original_freqs_cos = self.rope.freqs_cos.narrow(0, original_position, seq_len)
+        original_freqs_sin = self.rope.freqs_sin.narrow(0, original_position, seq_len)
+        new_freqs_cos = self.rope.freqs_cos.narrow(0, new_position, seq_len)
+        new_freqs_sin = self.rope.freqs_sin.narrow(0, new_position, seq_len)
+        rerotation_cos = (
+            new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin
+        )
+        rerotation_sin = (
+            new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin
+        )
+
+        q, k = self.rope.apply_rotary_emb(q, k, rerotation_cos, rerotation_sin)
+        return q, k
+
+
 class KVCacheWithAttentionSink(nn.Module):
     """
     KV cache that supports attention sink. It keeps the initial few tokens as attention sink.
@@ -114,46 +157,3 @@ def update(
         narrowed_k.copy_(k_val)
         narrowed_v.copy_(v_val)
         return self.k_cache, self.v_cache
-
-
-class RopeWithAttentionSink(nn.Module):
-    """
-    Rope that helps adjust position encoding when tokens are shifted in KVCache.
-    For AttentionSink, when tokens are shifted in KVCache, we need to use positions
-    in KVCache instead of positions in the actual text.
-    """
-
-    def __init__(self, rope: Rope):
-        super().__init__()
-        self.rope = rope
-
-    def forward(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        original_position: int,
-        new_position: int,
-        seq_len: int,
-    ):
-        """
-        Rerotate q and k from original_position to new_position. This is done by rerotating q
-        and k with (new_position * theta - original_position * theta) with the following matrix:
-        (cos(delta), -sin(delta)
-         sin(delta), cos(delta))
-        where delta = new_position * theta - original_position * theta
-
-        Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961
-        """
-        original_freqs_cos = self.rope.freqs_cos.narrow(0, original_position, seq_len)
-        original_freqs_sin = self.rope.freqs_sin.narrow(0, original_position, seq_len)
-        new_freqs_cos = self.rope.freqs_cos.narrow(0, new_position, seq_len)
-        new_freqs_sin = self.rope.freqs_sin.narrow(0, new_position, seq_len)
-        rerotation_cos = (
-            new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin
-        )
-        rerotation_sin = (
-            new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin
-        )
-
-        q, k = self.rope.apply_rotary_emb(q, k, rerotation_cos, rerotation_sin)
-        return q, k
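For reference, the `rerotation_cos` and `rerotation_sin` expressions in the added `forward` are just the angle-subtraction identities evaluated per rope frequency, so applying `apply_rotary_emb` with them rotates q and k by exactly the position delta:

```math
\begin{aligned}
\cos(\theta_{\mathrm{new}} - \theta_{\mathrm{old}}) &= \cos\theta_{\mathrm{new}}\cos\theta_{\mathrm{old}} + \sin\theta_{\mathrm{new}}\sin\theta_{\mathrm{old}} \\
\sin(\theta_{\mathrm{new}} - \theta_{\mathrm{old}}) &= \sin\theta_{\mathrm{new}}\cos\theta_{\mathrm{old}} - \cos\theta_{\mathrm{new}}\sin\theta_{\mathrm{old}}
\end{aligned}
```

where theta_old = original_position * theta and theta_new = new_position * theta for each frequency theta.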

examples/models/llama/source_transformation/test_attention_sink.py

Lines changed: 107 additions & 107 deletions
@@ -15,6 +15,113 @@
 )
 
 
+class RopeWithAttentionSinkTest(unittest.TestCase):
+
+    def setUp(self):
+        self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
+        self.rope = Rope(self.params)
+        self.rope_with_attention_sink = RopeWithAttentionSink(rope=self.rope)
+        self.seq_len = 32
+        self.n_local_heads = self.params.n_heads
+        self.n_local_kv_heads = self.params.n_heads
+        self.head_dim = self.params.dim // self.params.n_heads
+        self.q = torch.ones(
+            (1, self.seq_len, self.n_local_heads, self.head_dim), dtype=torch.float32
+        )
+        self.k = torch.full(
+            (1, self.seq_len, self.n_local_kv_heads, self.head_dim),
+            2,
+            dtype=torch.float32,
+        )
+
+    def test_rotate_backward(self):
+        original_position = 128
+        new_position = 127
+
+        pre_rotated_q, pre_rotated_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([original_position], dtype=torch.int32),
+        )
+
+        q, k = self.rope_with_attention_sink.forward(
+            q=pre_rotated_q,
+            k=pre_rotated_k,
+            original_position=original_position,
+            new_position=new_position,
+            seq_len=self.seq_len,
+        )
+
+        expected_q, expected_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([new_position], dtype=torch.int32),
+        )
+
+        torch.testing.assert_close(q, expected_q)
+        torch.testing.assert_close(k, expected_k)
+
+    def test_rotate_inplace(self):
+        original_position = 128
+        new_position = 128
+
+        pre_rotated_q, pre_rotated_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([original_position], dtype=torch.int32),
+        )
+
+        q, k = self.rope_with_attention_sink.forward(
+            q=pre_rotated_q,
+            k=pre_rotated_k,
+            original_position=original_position,
+            new_position=new_position,
+            seq_len=self.seq_len,
+        )
+
+        expected_q, expected_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([new_position], dtype=torch.int32),
+        )
+
+        torch.testing.assert_close(q, expected_q)
+        torch.testing.assert_close(k, expected_k)
+
+    def test_rotate_forward(self):
+        original_position = 128
+        new_position = 129
+
+        pre_rotated_q, pre_rotated_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([original_position], dtype=torch.int32),
+        )
+
+        q, k = self.rope_with_attention_sink.forward(
+            q=pre_rotated_q,
+            k=pre_rotated_k,
+            original_position=original_position,
+            new_position=new_position,
+            seq_len=self.seq_len,
+        )
+
+        expected_q, expected_k = self.rope.forward(
+            q=self.q,
+            k=self.k,
+            seq_len=self.seq_len,
+            input_pos=torch.tensor([new_position], dtype=torch.int32),
+        )
+
+        torch.testing.assert_close(q, expected_q)
+        torch.testing.assert_close(k, expected_k)
+
+
 class KVCacheWithAttentionSinkTest(unittest.TestCase):
 
     def _init_cache(self):
@@ -178,110 +285,3 @@ def test_update_with_all_shift(self):
 
         torch.testing.assert_close(k_out, expected_k_out)
         torch.testing.assert_close(v_out, expected_v_out)
-
-
-class RopeWithAttentionSinkTest(unittest.TestCase):
-
-    def setUp(self):
-        self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
-        self.rope = Rope(self.params)
-        self.rope_with_attention_sink = RopeWithAttentionSink(rope=self.rope)
-        self.seq_len = 32
-        self.n_local_heads = self.params.n_heads
-        self.n_local_kv_heads = self.params.n_heads
-        self.head_dim = self.params.dim // self.params.n_heads
-        self.q = torch.ones(
-            (1, self.seq_len, self.n_local_heads, self.head_dim), dtype=torch.float32
-        )
-        self.k = torch.full(
-            (1, self.seq_len, self.n_local_kv_heads, self.head_dim),
-            2,
-            dtype=torch.float32,
-        )
-
-    def test_rotate_backward(self):
-        original_position = 128
-        new_position = 127
-
-        pre_rotated_q, pre_rotated_k = self.rope.forward(
-            q=self.q,
-            k=self.k,
-            seq_len=self.seq_len,
-            input_pos=torch.tensor([original_position], dtype=torch.int32),
-        )
-
-        q, k = self.rope_with_attention_sink.forward(
-            q=pre_rotated_q,
-            k=pre_rotated_k,
-            original_position=original_position,
-            new_position=new_position,
-            seq_len=self.seq_len,
-        )
-
-        expected_q, expected_k = self.rope.forward(
-            q=self.q,
-            k=self.k,
-            seq_len=self.seq_len,
-            input_pos=torch.tensor([new_position], dtype=torch.int32),
-        )
-
-        torch.testing.assert_close(q, expected_q)
-        torch.testing.assert_close(k, expected_k)
-
-    def test_rotate_inplace(self):
-        original_position = 128
-        new_position = 128
-
-        pre_rotated_q, pre_rotated_k = self.rope.forward(
-            q=self.q,
-            k=self.k,
-            seq_len=self.seq_len,
-            input_pos=torch.tensor([original_position], dtype=torch.int32),
-        )
-
-        q, k = self.rope_with_attention_sink.forward(
-            q=pre_rotated_q,
-            k=pre_rotated_k,
-            original_position=original_position,
-            new_position=new_position,
-            seq_len=self.seq_len,
-        )
-
-        expected_q, expected_k = self.rope.forward(
-            q=self.q,
-            k=self.k,
-            seq_len=self.seq_len,
-            input_pos=torch.tensor([new_position], dtype=torch.int32),
-        )
-
-        torch.testing.assert_close(q, expected_q)
-        torch.testing.assert_close(k, expected_k)
-
-    def test_rotate_forward(self):
-        original_position = 128
-        new_position = 129
-
-        pre_rotated_q, pre_rotated_k = self.rope.forward(
-            q=self.q,
-            k=self.k,
-            seq_len=self.seq_len,
-            input_pos=torch.tensor([original_position], dtype=torch.int32),
-        )
-
-        q, k = self.rope_with_attention_sink.forward(
-            q=pre_rotated_q,
-            k=pre_rotated_k,
-            original_position=original_position,
-            new_position=new_position,
-            seq_len=self.seq_len,
-        )
-
-        expected_q, expected_k = self.rope.forward(
-            q=self.q,
-            k=self.k,
-            seq_len=self.seq_len,
-            input_pos=torch.tensor([new_position], dtype=torch.int32),
-        )
-
-        torch.testing.assert_close(q, expected_q)
-        torch.testing.assert_close(k, expected_k)
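Independently of the ExecuTorch modules, the identity these tests exercise can be sanity-checked in isolation; a standalone sketch (not part of this diff) using a plain 2-D rotation:

```python
# Rotating by theta_old and then by the delta (theta_new - theta_old) must
# equal rotating by theta_new directly, which is the property that
# RopeWithAttentionSink relies on when it rerotates shifted tokens.
import math

import torch


def rotate(x: torch.Tensor, y: torch.Tensor, angle: float):
    cos, sin = math.cos(angle), math.sin(angle)
    return x * cos - y * sin, x * sin + y * cos


theta_old, theta_new = 128 * 0.01, 127 * 0.01  # arbitrary example angles
x, y = torch.tensor(1.0), torch.tensor(2.0)

x_old, y_old = rotate(x, y, theta_old)  # encode at the original position
x_shifted, y_shifted = rotate(x_old, y_old, theta_new - theta_old)  # reapply delta
x_direct, y_direct = rotate(x, y, theta_new)  # encode at the new position

torch.testing.assert_close(x_shifted, x_direct)
torch.testing.assert_close(y_shifted, y_direct)
```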
