
Commit 47dcd57

add attention_sink.py
This PR adds `KVCacheWithAttentionSink`, which is required for `AttentionSink`. It keeps the first `sink_size` tokens as attention sinks and maintains a sliding window of `window_size` for new tokens.

Note: I am trying to implement and verify `AttentionSink` in eager mode first, so the current implementation may still have some errors or performance issues. For example, it does not support the case where dynamic shape is disabled. I will leave these problems to be resolved when we are ready to deploy `AttentionSink` to edge.

Differential Revision: [D65235798](https://our.internmc.facebook.com/intern/diff/D65235798/)

[ghstack-poisoned]
1 parent f2ad9d0 commit 47dcd57
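
For orientation, the cache layout described in the commit message can be sketched as follows (an illustration added here, not part of the commit; the names mirror the fields defined in attention_sink.py below):

# Illustration only: logical layout of the KV cache along the sequence dimension.
#
#   index:  0 .. sink_size-1        | sink_size .. sink_size+window_size-1
#   role:   attention sinks (kept)  | sliding window of the most recent tokens
#
# cache_size = sink_size + window_size. Sink tokens are never evicted; when new
# tokens overflow the cache, the oldest window tokens are shifted out instead.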

File tree

3 files changed: +304 -0 lines changed


examples/models/llama/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -92,6 +92,7 @@ runtime.python_library(
         "source_transformation/sdpa.py",
         "source_transformation/spin_quant.py",
         "source_transformation/vulkan_rope.py",
+        "source_transformation/attention_sink.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",
@@ -212,3 +213,14 @@ runtime.python_test(
         "//executorch/examples/models/llama:llama_transformer",
     ],
 )
+
+runtime.python_test(
+    name = "attention_sink_test",
+    srcs = [
+        "source_transformation/test_attention_sink.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        ":export_library",
+    ],
+)
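
Assuming a Buck2-based checkout where the //executorch/examples/models/llama cell resolves (as the existing deps above suggest), the new test target could presumably be run with something like:

buck2 test //executorch/examples/models/llama:attention_sink_test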
examples/models/llama/source_transformation/attention_sink.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

from typing import Tuple

import torch

from torch import nn


class KVCacheWithAttentionSink(nn.Module):
    """
    KV cache that supports attention sink. It keeps the initial few tokens as attention sink.
    For other tokens, it uses a sliding window to keep the most recent tokens.

    Parameters:
        window_size: the size of the sliding window
        sink_size: the number of initial tokens to keep as attention sink
    """

    def __init__(
        self,
        max_batch_size: int,
        window_size: int,
        sink_size: int,
        n_heads: int,
        head_dim: int,
        transpose_cache: bool,
        dtype=torch.float32,
    ):
        super().__init__()
        self.window_size = window_size
        self.sink_size = sink_size
        self.cache_size = window_size + sink_size
        self.is_transposed = transpose_cache
        if transpose_cache:
            cache_shape = (max_batch_size, n_heads, self.cache_size, head_dim)
        else:
            cache_shape = (max_batch_size, self.cache_size, n_heads, head_dim)

        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.transpose_cache = transpose_cache
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        start_pos = input_pos[0].item()
        torch._check_is_size(start_pos)
        dim_to_slice = 2 if self.transpose_cache else 1
        seq_length = k_val.size(dim_to_slice)

        if start_pos + seq_length <= self.cache_size:
            # There are still enough spaces in the cache to store the new tokens.
            # No need to shift the existing tokens.
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

            narrowed_k.copy_(k_val)
            narrowed_v.copy_(v_val)
        else:
            # There are not enough spaces in the cache to store the new tokens.
            # We need to shift the existing tokens.
            num_to_evict = min(start_pos + seq_length - self.cache_size, seq_length)

            # Shift the existing entries to the left
            # pyre-ignore: Incompatible parameter type [6]
            k_to_keep = self.k_cache.narrow(
                dim_to_slice,
                self.sink_size + num_to_evict,
                self.window_size - num_to_evict,
            ).clone()
            # pyre-ignore: Incompatible parameter type [6]
            v_to_keep = self.v_cache.narrow(
                dim_to_slice,
                self.sink_size + num_to_evict,
                self.window_size - num_to_evict,
            ).clone()
            # pyre-ignore: Incompatible parameter type [6]
            k_new_position = self.k_cache.narrow(
                dim_to_slice, self.sink_size, self.window_size - num_to_evict
            )
            # pyre-ignore: Incompatible parameter type [6]
            v_new_position = self.v_cache.narrow(
                dim_to_slice, self.sink_size, self.window_size - num_to_evict
            )
            k_new_position.copy_(k_to_keep)
            v_new_position.copy_(v_to_keep)

            # Appending new entries
            narrowed_k = self.k_cache.narrow(
                dim_to_slice, self.cache_size - seq_length, seq_length
            )
            narrowed_v = self.v_cache.narrow(
                dim_to_slice, self.cache_size - seq_length, seq_length
            )
            narrowed_k.copy_(k_val)
            narrowed_v.copy_(v_val)
        return self.k_cache, self.v_cache
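
As a quick usage sketch of the class above (not part of the diff; the import path comes from the test file below, and the parameter values simply mirror its setUp with transpose_cache=False):

import torch

from executorch.examples.models.llama.source_transformation.attention_sink import (
    KVCacheWithAttentionSink,
)

# Parameters chosen only for illustration, matching the tests below.
cache = KVCacheWithAttentionSink(
    max_batch_size=1,
    window_size=28,
    sink_size=4,
    n_heads=8,
    head_dim=16,
    transpose_cache=False,
)

# Prefill: 32 tokens starting at position 0 exactly fill the cache
# (cache_size = sink_size + window_size = 4 + 28 = 32), so no shifting happens.
k = torch.randn(1, 32, 8, 16)
v = torch.randn(1, 32, 8, 16)
cache.update(torch.tensor([0], dtype=torch.int32), k, v)

# Decoding one more token at position 32 overflows the cache: the first 4 sink
# tokens stay in place, the oldest window token is evicted, the remaining 27
# window tokens shift left by one, and the new token is written at the end.
k_new = torch.randn(1, 1, 8, 16)
v_new = torch.randn(1, 1, 8, 16)
k_out, v_out = cache.update(torch.tensor([32], dtype=torch.int32), k_new, v_new)
assert k_out.shape == (1, 32, 8, 16)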
examples/models/llama/source_transformation/test_attention_sink.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.examples.models.llama.source_transformation.attention_sink import (
    KVCacheWithAttentionSink,
)


class KVCacheWithAttentionSinkTest(unittest.TestCase):

    def _init_cache(self):
        self.kv_cache = KVCacheWithAttentionSink(
            max_batch_size=self.max_batch_size,
            window_size=self.window_size,
            sink_size=self.sink_size,
            n_heads=self.n_heads,
            head_dim=self.head_dim,
            transpose_cache=self.transpose_cache,
            dtype=self.dtype,
        )

    def setUp(self):
        torch.manual_seed(42)
        self.max_batch_size = 1
        self.window_size = 28
        self.sink_size = 4
        self.n_heads = 8
        self.head_dim = 16
        self.transpose_cache = False
        self.dtype = torch.float32
        self._init_cache()

    def test_update_empty_cache(self):
        # KV cache is empty, update will fill sink tokens
        input_pos = torch.tensor([0], dtype=torch.int32)
        k = torch.ones((1, 1, 8, 16), dtype=self.dtype)
        v = torch.ones((1, 1, 8, 16), dtype=self.dtype)

        k_out, v_out = self.kv_cache.update(input_pos, k, v)

        expected_k_out = torch.cat(
            [
                torch.ones((1, 1, 8, 16), dtype=self.dtype),
                torch.zeros((1, 31, 8, 16), dtype=self.dtype),
            ],
            dim=1,
        )
        expected_v_out = torch.cat(
            [
                torch.ones((1, 1, 8, 16), dtype=self.dtype),
                torch.zeros((1, 31, 8, 16), dtype=self.dtype),
            ],
            dim=1,
        )

        torch.testing.assert_close(k_out, expected_k_out)
        torch.testing.assert_close(v_out, expected_v_out)

    def test_update_without_shift(self):
        # KV cache has enough spaces for new tokens, no shift
        input_pos = torch.tensor([0], dtype=torch.int32)
        k = torch.ones((1, 5, 8, 16), dtype=self.dtype)
        v = torch.ones((1, 5, 8, 16), dtype=self.dtype)

        self.kv_cache.update(input_pos, k, v)

        input_pos = torch.tensor([5], dtype=torch.int32)
        k = torch.full((1, 5, 8, 16), 2, dtype=self.dtype)
        v = torch.full((1, 5, 8, 16), 2, dtype=self.dtype)

        k_out, v_out = self.kv_cache.update(input_pos, k, v)

        expected_k_out = torch.cat(
            [
                torch.ones((1, 5, 8, 16), dtype=self.dtype),
                torch.full((1, 5, 8, 16), 2, dtype=self.dtype),
                torch.zeros((1, 22, 8, 16), dtype=self.dtype),
            ],
            dim=1,
        )
        expected_v_out = torch.cat(
            [
                torch.ones((1, 5, 8, 16), dtype=self.dtype),
                torch.full((1, 5, 8, 16), 2, dtype=self.dtype),
                torch.zeros((1, 22, 8, 16), dtype=self.dtype),
            ],
            dim=1,
        )

        torch.testing.assert_close(k_out, expected_k_out)
        torch.testing.assert_close(v_out, expected_v_out)

    def test_update_with_some_shift(self):
        # KV cache has some spaces for new tokens but not all, shift some tokens
        input_pos = torch.tensor([0], dtype=torch.int32)
        k = torch.ones((1, 5, 8, 16), dtype=self.dtype)
        v = torch.ones((1, 5, 8, 16), dtype=self.dtype)

        self.kv_cache.update(input_pos, k, v)

        input_pos = torch.tensor([5], dtype=torch.int32)
        k = torch.full((1, 5, 8, 16), 2, dtype=self.dtype)
        v = torch.full((1, 5, 8, 16), 2, dtype=self.dtype)

        self.kv_cache.update(input_pos, k, v)

        input_pos = torch.tensor([10], dtype=torch.int32)
        k = torch.full((1, 24, 8, 16), 3, dtype=self.dtype)
        v = torch.full((1, 24, 8, 16), 3, dtype=self.dtype)

        k_out, v_out = self.kv_cache.update(input_pos, k, v)

        expected_k_out = torch.cat(
            [
                torch.ones((1, 4, 8, 16), dtype=self.dtype),
                torch.full((1, 4, 8, 16), 2, dtype=self.dtype),
                torch.full((1, 24, 8, 16), 3, dtype=self.dtype),
            ],
            dim=1,
        )
        expected_v_out = torch.cat(
            [
                torch.ones((1, 4, 8, 16), dtype=self.dtype),
                torch.full((1, 4, 8, 16), 2, dtype=self.dtype),
                torch.full((1, 24, 8, 16), 3, dtype=self.dtype),
            ],
            dim=1,
        )

        torch.testing.assert_close(k_out, expected_k_out)
        torch.testing.assert_close(v_out, expected_v_out)

    def test_update_with_all_shift(self):
        # KV cache has no spaces for new tokens, shift all tokens
        input_pos = torch.tensor([0], dtype=torch.int32)
        k = torch.ones((1, 5, 8, 16), dtype=self.dtype)
        v = torch.ones((1, 5, 8, 16), dtype=self.dtype)

        self.kv_cache.update(input_pos, k, v)

        input_pos = torch.tensor([5], dtype=torch.int32)
        k = torch.full((1, 28, 8, 16), 2, dtype=self.dtype)
        v = torch.full((1, 28, 8, 16), 2, dtype=self.dtype)

        self.kv_cache.update(input_pos, k, v)

        input_pos = torch.tensor([33], dtype=torch.int32)
        k = torch.full((1, 6, 8, 16), 3, dtype=self.dtype)
        v = torch.full((1, 6, 8, 16), 3, dtype=self.dtype)

        k_out, v_out = self.kv_cache.update(input_pos, k, v)

        expected_k_out = torch.cat(
            [
                torch.ones((1, 4, 8, 16), dtype=self.dtype),
                torch.full((1, 22, 8, 16), 2, dtype=self.dtype),
                torch.full((1, 6, 8, 16), 3, dtype=self.dtype),
            ],
            dim=1,
        )
        expected_v_out = torch.cat(
            [
                torch.ones((1, 4, 8, 16), dtype=self.dtype),
                torch.full((1, 22, 8, 16), 2, dtype=self.dtype),
                torch.full((1, 6, 8, 16), 3, dtype=self.dtype),
            ],
            dim=1,
        )

        torch.testing.assert_close(k_out, expected_k_out)
        torch.testing.assert_close(v_out, expected_v_out)
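
To see why test_update_with_all_shift expects 4 ones, 22 twos, and 6 threes, the eviction arithmetic from KVCacheWithAttentionSink.update can be traced by hand (a reading of the code above, not part of the diff):

# Third update in test_update_with_all_shift: start_pos = 33, 6 new tokens.
cache_size = 4 + 28                                                  # sink_size + window_size = 32
start_pos, seq_length = 33, 6
num_to_evict = min(start_pos + seq_length - cache_size, seq_length)  # = 6
kept = 28 - num_to_evict                                             # 22 value-2 window tokens slide left
# Final layout: 4 sink tokens (value 1) | 22 shifted window tokens (value 2) | 6 new tokens (value 3)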
