
Commit cd34c47

Sliding window KV cache
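
Turn each static KV cache into a ring buffer so generation can continue past the cache length with a sliding attention window. The single valid_len_ counter is replaced by per-cache lengths (cache_len_) and write positions (cache_pos_); update_one_cache() copies new K/V entries at the current write position and wraps back to the start of the cache when it reaches the end. StaticAttentionMask::unmask() is clamped so at most cache_len_ cache positions are ever unmasked, and the decode / lookahead-decode loops no longer stop when the cache fills.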

Differential Revision: D79128246
Pull Request resolved: pytorch#12975
1 parent 8170f8f commit cd34c47

File tree

1 file changed: +80 -45 lines changed
examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 80 additions & 45 deletions
@@ -8,6 +8,7 @@
 
 #include <algorithm>
 #include <memory>
+#include <numeric>
 #include <tuple>
 #include <unordered_map>
 #include <vector>
@@ -44,14 +45,17 @@ class StaticKVCache {
       size_t n_heads_per_cache = 1,
       StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
       : n_caches_(n_caches),
-        cache_len_(cache_len),
+        cache_len_(n_caches_, cache_len),
+        cache_pos_(n_caches_, 0),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
         style_(style),
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
-    cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
+    size_t total_cache_len =
+        std::accumulate(cache_len_.begin(), cache_len_.end(), 0);
+    cache_data_size_ = total_cache_len * n_heads_per_cache_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
@@ -122,7 +126,7 @@ class StaticKVCache {
       ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
       ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
       ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
+          inSizes[ndim - 2] == cache_len_[i], "Cache length dim mismatch.");
 
       auto impl = ::executorch::runtime::etensor::TensorImpl(
           inMeta->scalar_type(),
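
Because cache lengths are now stored per cache, the flat cache buffer is sized by summing them (the std::accumulate call in the constructor hunk above) and each cache's pointer is found with a running offset (the init_ptrs() change further down). Below is a standalone sketch of that layout, not part of the commit; the head counts, head dimension, and per-cache lengths are made up.

// Layout sketch (hypothetical sizes): one flat buffer holding caches that
// may have different lengths, addressed by running offsets.
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const size_t n_heads_per_cache = 2;
  const size_t head_dim = 4;
  // The commit's constructor fills this with a single cache_len value,
  // but the layout below also works for mixed lengths.
  std::vector<size_t> cache_len = {8, 8, 4};

  size_t total_cache_len =
      std::accumulate(cache_len.begin(), cache_len.end(), size_t(0));
  size_t cache_data_size = total_cache_len * n_heads_per_cache * head_dim;
  std::cout << "total elements: " << cache_data_size << '\n';

  // Running offset of each cache inside the flat buffer, as in init_ptrs().
  size_t offset = 0;
  for (size_t i = 0; i < cache_len.size(); i++) {
    std::cout << "cache " << i << " starts at element " << offset << '\n';
    offset += cache_len[i] * n_heads_per_cache * head_dim;
  }
  return 0;
}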
@@ -150,55 +154,88 @@ class StaticKVCache {
   void update(
       torch::executor::Method& method,
       const std::vector<size_t>& output_indices,
-      size_t update_len,
+      size_t update_n,
       size_t update_pos = 0) {
-    if (valid_len_ + update_len > cache_len_) {
-      throw std::runtime_error("Cache capacity exceeded.");
-    }
-
     for (size_t i = 0; i < n_caches_; i++) {
       const auto& updateTensor =
           method.get_output(output_indices[i]).toTensor();
       ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
-      for (size_t j = 0; j < n_heads_per_cache_; j++) {
-        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
-        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
-        std::copy(
-            update_head + update_pos * head_dim_,
-            update_head + (update_pos + update_len) * head_dim_,
-            cache_head + valid_len_ * head_dim_);
-      }
+      size_t update_len = updateTensor.size(updateTensor.dim() - 2);
+      cache_pos_[i] = update_one_cache(
+          output_ptrs_[i],
+          update_len,
+          update_n,
+          update_pos,
+          input_ptrs_[i],
+          cache_len_[i],
+          cache_pos_[i]);
     }
-    valid_len_ += update_len;
   }
 
   /**
-   * Reset the cache. After this the cache contains no valid data and is ready
-   * for number of tokens up to the cache length.
+   * Reset the cache. After this the cache contains no valid data and the mask
+   * should be updated to reflect this.
    */
   void reset() {
-    valid_len_ = 0;
-  }
-
-  size_t size() {
-    return valid_len_;
+    std::fill(cache_pos_.begin(), cache_pos_.end(), 0);
   }
 
  private:
   void init_ptrs() {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
+    size_t cache_data_offset = 0;
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] =
-          cache_data_ + i * n_heads_per_cache_ * cache_len_ * head_dim_;
+      input_ptrs_[i] = cache_data_ + cache_data_offset;
+      cache_data_offset += cache_len_[i] * n_heads_per_cache_ * head_dim_;
       output_ptrs_[i] =
           update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
     }
   }
 
+  size_t update_one_cache(
+      const T* update,
+      size_t update_len,
+      size_t update_n,
+      size_t update_pos,
+      T* cache,
+      size_t cache_len,
+      size_t cache_pos) {
+    size_t wrap_n = 0;
+    auto contiguous_n = cache_len - cache_pos;
+    if (update_n > contiguous_n) {
+      wrap_n = update_n - contiguous_n;
+      update_n = contiguous_n;
+    }
+
+    // Update & cache shape: (1, n_heads, seq_len, head_dim)
+    for (size_t head = 0; head < n_heads_per_cache_; head++) {
+      auto* update_head = update + update_len * head_dim_ * head;
+      auto* cache_head = cache + cache_len * head_dim_ * head;
+      std::copy(
+          update_head + update_pos * head_dim_,
+          update_head + (update_pos + update_n) * head_dim_,
+          cache_head + cache_pos * head_dim_);
+    }
+    cache_pos += update_n;
+
+    if (wrap_n > 0) {
+      return update_one_cache(
+          update,
+          update_len,
+          wrap_n,
+          update_pos + contiguous_n,
+          cache,
+          cache_len,
+          0);
+    }
+
+    return cache_pos;
+  }
+
   size_t n_caches_;
-  size_t cache_len_;
+  std::vector<size_t> cache_len_;
+  std::vector<size_t> cache_pos_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
   size_t head_dim_;
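
The wrap-around copy above is the heart of the sliding window. Below is a minimal standalone sketch of the same behavior, not part of the commit: ring_write and all sizes and values are hypothetical, and it loops over the wrapped tail where update_one_cache() recurses once.

// Standalone sketch of the ring-buffer write that the commit's
// update_one_cache() performs (made-up sizes and values).
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Copy `update_n` rows of `head_dim` floats from `update` into `cache`,
// starting at `cache_pos` and wrapping to the beginning when the end of the
// cache is reached. Returns the new write position.
size_t ring_write(
    const std::vector<float>& update,
    size_t update_n,
    std::vector<float>& cache,
    size_t cache_len,
    size_t head_dim,
    size_t cache_pos) {
  size_t update_pos = 0;
  while (update_n > 0) {
    size_t contiguous_n = std::min(update_n, cache_len - cache_pos);
    std::copy(
        update.begin() + update_pos * head_dim,
        update.begin() + (update_pos + contiguous_n) * head_dim,
        cache.begin() + cache_pos * head_dim);
    update_pos += contiguous_n;
    update_n -= contiguous_n;
    cache_pos = (cache_pos + contiguous_n) % cache_len;
  }
  return cache_pos;
}

int main() {
  const size_t cache_len = 4; // sliding window of 4 positions
  const size_t head_dim = 2;
  std::vector<float> cache(cache_len * head_dim, 0.0f);

  // Write 3 positions, then 3 more; the second write wraps around.
  std::vector<float> a = {1, 1, 2, 2, 3, 3};
  std::vector<float> b = {4, 4, 5, 5, 6, 6};
  size_t pos = ring_write(a, 3, cache, cache_len, head_dim, 0);
  pos = ring_write(b, 3, cache, cache_len, head_dim, pos);

  // The window now holds positions 5, 6, 3, 4: the oldest entries were
  // overwritten in place.
  for (float v : cache) {
    std::cout << v << ' ';
  }
  std::cout << "\nnext write position: " << pos << '\n';
  return 0;
}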
@@ -210,7 +247,6 @@ class StaticKVCache {
   T* update_data_;
   std::vector<T*> input_ptrs_;
   std::vector<T*> output_ptrs_;
-  size_t valid_len_ = 0;
 };
 
 template <typename T, typename AllocatorT = std::allocator<T>>
@@ -267,17 +303,20 @@ class StaticAttentionMask {
   }
 
   /**
-   * Update the mask to indicate update_len elements have been added to the
-   * cache. Note that update_len might be smaller than input_len_ when
-   * prefilling with padded inputs.
+   * Update the mask to indicate update_n elements have been added to the
+   * cache. Note that update_n might be smaller than input_len_ when prefilling
+   * with padded inputs.
    */
-  void unmask(size_t update_len) {
-    for (size_t i = 0; i < input_len_; i++) {
-      auto* p = data_ + (cache_len_ + input_len_) * i;
-      std::fill(
-          p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
+  void unmask(size_t update_n) {
+    update_n = std::min(update_n, cache_len_ - cache_valid_len_);
+    if (update_n > 0) {
+      for (size_t i = 0; i < input_len_; i++) {
+        auto* p = data_ + (cache_len_ + input_len_) * i;
+        std::fill(
+            p + cache_valid_len_, p + cache_valid_len_ + update_n, zero_val_);
+      }
+      cache_valid_len_ += update_n;
     }
-    cache_valid_len_ += update_len;
   }
 
   void set_causal_mask() {
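
The clamp in unmask() keeps cache_valid_len_ from ever exceeding cache_len_: once the window is full, every cache position stays visible while new K/V entries overwrite the oldest ones. A toy sketch of the row layout and the clamp follows; only the (cache_len_ + input_len_) row stride and the clamping logic come from the diff, while the ToyMask name, sizes, and fill values are assumptions.

// Toy mask sketch (not the real StaticAttentionMask): each of input_len rows
// has cache_len + input_len columns; unmask() clears the next update_n cache
// columns in every row, clamped so no more than cache_len columns are open.
#include <algorithm>
#include <cstddef>
#include <vector>

struct ToyMask {
  size_t cache_len;
  size_t input_len;
  size_t cache_valid_len = 0;
  float zero_val = 0.0f; // value for "visible" entries
  std::vector<float> data; // starts fully masked with -1e9 (assumed value)

  ToyMask(size_t cache_len, size_t input_len)
      : cache_len(cache_len),
        input_len(input_len),
        data(input_len * (cache_len + input_len), -1e9f) {}

  void unmask(size_t update_n) {
    update_n = std::min(update_n, cache_len - cache_valid_len);
    if (update_n > 0) {
      for (size_t i = 0; i < input_len; i++) {
        float* p = data.data() + (cache_len + input_len) * i;
        std::fill(p + cache_valid_len, p + cache_valid_len + update_n, zero_val);
      }
      cache_valid_len += update_n;
    }
  }
};

int main() {
  ToyMask m(/*cache_len=*/4, /*input_len=*/1);
  m.unmask(3); // 3 of 4 cache columns become visible
  m.unmask(3); // clamped: only 1 more column opens; cache_valid_len == 4
  m.unmask(2); // no-op: the window is already fully visible
  return 0;
}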
@@ -565,7 +604,7 @@ class StaticAttentionIOManager {
     set_input(method, config_.attn_mask_input_index, mask.get());
 
     std::vector<TokenT> generated_tokens;
-    while (kCaches_.size() + 1 <= config_.cache_len) {
+    while (true) {
       input_buffer[0] = prev_tok;
       prepare(method);
       ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
@@ -629,7 +668,7 @@ class StaticAttentionIOManager {
         std::max(window_size * (ngram_size - 1), static_cast<size_t>(1));
     size_t n_inference = 0;
     std::fill(input_buffer.begin(), input_buffer.end(), prev_tok);
-    while (kCaches_.size() + 1 <= config_.cache_len) {
+    while (true) {
       input_buffer[0] = prev_tok;
       // Initialize verification branches.
       if (auto it = suffix_caches.find(prev_tok); it != suffix_caches.end()) {
@@ -697,9 +736,6 @@ class StaticAttentionIOManager {
                  input_buffer[branch_offset + j] == match.back();
              j++) {
           match.emplace_back(output_toks[branch_offset + j]);
-          if (should_stop(match.back())) {
-            break;
-          }
         }
         if (match.size() > longest_match.size()) {
           longest_match = std::move(match);
@@ -724,8 +760,7 @@ class StaticAttentionIOManager {
           method,
          config_.k_cache_output_indices,
          config_.v_cache_output_indices,
-          std::min(
-              longest_match.size() - 1, config_.cache_len - kCaches_.size()),
+          longest_match.size() - 1,
          branch_offset);
      }
 
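
With the caches wrapping instead of filling up, the decode and lookahead-decode loops above no longer bound their iteration count by config_.cache_len - kCaches_.size(); both now run as while (true) and exit through stop conditions inside the loop body (presumably the existing stop-token handling, which these hunks do not show). For the same reason the lookahead path stops clamping the number of accepted speculative tokens to the remaining cache capacity and passes longest_match.size() - 1 directly, and the per-token should_stop break inside the verification-match loop is removed.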
