Skip to content

Commit 7140dec

Browse files
committed
Update on "add attention_sink.py"
This PR adds `KVCacheWithAttentionSink`, which is required for `AttentionSink`. It keeps the first `sink_size` tokens as attention sinks and maintains a sliding window with `window_size` for new tokens. Note: I am trying to implement and verify `AttentionSink` in eager mode first. So the current implementation may still have some lower errors or performance issue. For example, it does not support the case when dynamic shape is disabled. Will leave these problems to resolve when we are ready to deploy `AttentionSink` to edge. Differential Revision: [D65235798](https://our.internmc.facebook.com/intern/diff/D65235798/) [ghstack-poisoned]
1 parent 47dcd57 commit 7140dec

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,25 +79,25 @@ def update(
7979
num_to_evict = min(start_pos + seq_length - self.cache_size, seq_length)
8080

8181
# Shift the existing entries to the left
82-
# pyre-ignore: Incompatible parameter type [6]
8382
k_to_keep = self.k_cache.narrow(
8483
dim_to_slice,
85-
self.sink_size + num_to_evict,
86-
self.window_size - num_to_evict,
84+
self.sink_size + num_to_evict, # pyre-ignore [6]
85+
self.window_size - num_to_evict, # pyre-ignore [6]
8786
).clone()
88-
# pyre-ignore: Incompatible parameter type [6]
8987
v_to_keep = self.v_cache.narrow(
9088
dim_to_slice,
91-
self.sink_size + num_to_evict,
92-
self.window_size - num_to_evict,
89+
self.sink_size + num_to_evict, # pyre-ignore [6]
90+
self.window_size - num_to_evict, # pyre-ignore [6]
9391
).clone()
94-
# pyre-ignore: Incompatible parameter type [6]
9592
k_new_position = self.k_cache.narrow(
96-
dim_to_slice, self.sink_size, self.window_size - num_to_evict
93+
dim_to_slice,
94+
self.sink_size,
95+
self.window_size - num_to_evict, # pyre-ignore [6]
9796
)
98-
# pyre-ignore: Incompatible parameter type [6]
9997
v_new_position = self.v_cache.narrow(
100-
dim_to_slice, self.sink_size, self.window_size - num_to_evict
98+
dim_to_slice,
99+
self.sink_size,
100+
self.window_size - num_to_evict, # pyre-ignore [6]
101101
)
102102
k_new_position.copy_(k_to_keep)
103103
v_new_position.copy_(v_to_keep)

0 commit comments

Comments
 (0)