
Commit dbbaa85

Update on "add attention_sink.py"
This PR adds `KVCacheWithAttentionSink`, which is required for `AttentionSink`. It keeps the first `sink_size` tokens as attention sinks and maintains a sliding window of `window_size` for new tokens.

Note: I am implementing and verifying `AttentionSink` in eager mode first, so the current implementation may still have some errors or performance issues. For example, it does not support the case where dynamic shape is disabled. These problems will be resolved when we are ready to deploy `AttentionSink` to edge.

Differential Revision: [D65235798](https://our.internmc.facebook.com/intern/diff/D65235798/)

[ghstack-poisoned]
2 parents 5de701d + 67aeda2 commit dbbaa85
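For orientation, here is a minimal sketch of the eviction policy described in the commit message: keep the first `sink_size` tokens as attention sinks plus a sliding window of the most recent `window_size` tokens, and drop everything in between. The function name and tensor layout are illustrative assumptions, not the PR's `KVCacheWithAttentionSink` API, and the RoPE re-rotation of the kept keys (which the diff below handles via `rerotate_k`) is omitted.

```python
import torch


def evict_from_window(k_cache: torch.Tensor, sink_size: int, window_size: int) -> torch.Tensor:
    """Assumes k_cache is laid out as (batch, seq_len, n_heads, head_dim)."""
    seq_len = k_cache.size(1)
    if seq_len <= sink_size + window_size:
        return k_cache  # nothing to evict yet
    sinks = k_cache[:, :sink_size]               # first sink_size tokens stay forever
    recent = k_cache[:, seq_len - window_size:]  # most recent window_size tokens
    # Note: the kept keys would also need RoPE re-rotation, since their
    # positions shift after eviction; that step is omitted in this sketch.
    return torch.cat([sinks, recent], dim=1)
```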

File tree

3 files changed: +244 / -230 lines


examples/models/llama/TARGETS

Lines changed: 2 additions & 0 deletions
```diff
@@ -220,7 +220,9 @@ runtime.python_test(
     srcs = [
         "source_transformation/test_attention_sink.py",
     ],
+    supports_static_listing = False,
     deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
         "//caffe2:torch",
         ":export_library",
     ],
```
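The new `parameterized` dependency together with `supports_static_listing = False` suggests the test file generates test cases via `parameterized.expand`, which cannot be listed statically by the build system. A hypothetical sketch of such a test (the class name and case values are made up, not taken from test_attention_sink.py):

```python
import unittest

from parameterized import parameterized


class ExampleAttentionSinkTest(unittest.TestCase):
    # Each tuple becomes its own generated test case at runtime,
    # which is why static test listing is disabled in TARGETS.
    @parameterized.expand([(4, 124), (4, 2044)])  # hypothetical (sink_size, window_size) pairs
    def test_window_is_larger_than_sink(self, sink_size, window_size):
        self.assertGreater(window_size, sink_size)
```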

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 17 additions & 8 deletions
```diff
@@ -100,7 +100,6 @@ def __init__(
         self.sink_size = sink_size
         self.eviction_batch_size = eviction_batch_size
         self.position_shift = 0
-        assert not transpose_cache
 
     def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
         """
@@ -134,16 +133,26 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
                 self.sink_size + num_to_evict,  # pyre-ignore [6]
                 num_to_keep,  # pyre-ignore [6]
             )
+            if self.transpose_cache:
+                k_to_keep = self.rope.rerotate_k(
+                    k=k_to_keep.transpose(1, 2),
+                    original_position=(  # pyre-ignore [6]
+                        self.sink_size + num_to_evict
+                    ),
+                    new_position=self.sink_size,
+                ).transpose(1, 2)
+            else:
+                k_to_keep = self.rope.rerotate_k(
+                    k=k_to_keep,
+                    original_position=(  # pyre-ignore [6]
+                        self.sink_size + num_to_evict
+                    ),
+                    new_position=self.sink_size,
+                )
             self.k_cache = torch.cat(
                 [
                     self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
-                    self.rope.rerotate_k(
-                        k=k_to_keep,
-                        original_position=(  # pyre-ignore [6]
-                            self.sink_size + num_to_evict
-                        ),
-                        new_position=self.sink_size,
-                    ),
+                    k_to_keep,
                     torch.zeros_like(
                         self.k_cache.narrow(
                             dim_to_slice, 0, num_empty_space  # pyre-ignore [6]
```