8
8
9
9
#include < algorithm>
10
10
#include < memory>
11
+ #include < numeric>
11
12
#include < tuple>
12
13
#include < unordered_map>
13
14
#include < vector>
@@ -44,14 +45,17 @@ class StaticKVCache {
44
45
size_t n_heads_per_cache = 1 ,
45
46
StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
46
47
: n_caches_(n_caches),
47
- cache_len_ (cache_len),
48
+ cache_len_ (n_caches_, cache_len),
49
+ cache_pos_(n_caches_, 0 ),
48
50
max_input_len_(max_input_len),
49
51
n_heads_per_cache_(n_heads_per_cache),
50
52
head_dim_(head_dim),
51
53
style_(style),
52
54
input_ptrs_(n_caches_),
53
55
output_ptrs_(n_caches_) {
54
- cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
56
+ size_t total_cache_len =
57
+ std::accumulate (cache_len_.begin (), cache_len_.end (), 0 );
58
+ cache_data_size_ = total_cache_len * n_heads_per_cache_ * head_dim_;
55
59
update_data_size_ =
56
60
n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
57
61
@@ -122,7 +126,7 @@ class StaticKVCache {
122
126
ET_CHECK_MSG (inSizes[ndim - 1 ] == head_dim_, " KV head dim mismatch." );
123
127
ET_CHECK_MSG (outSizes[ndim - 1 ] == head_dim_, " KV head dim mismatch." );
124
128
ET_CHECK_MSG (
125
- inSizes[ndim - 2 ] == cache_len_, " Cache length dim mismatch." );
129
+ inSizes[ndim - 2 ] == cache_len_[i] , " Cache length dim mismatch." );
126
130
127
131
auto impl = ::executorch::runtime::etensor::TensorImpl (
128
132
inMeta->scalar_type (),
@@ -150,55 +154,88 @@ class StaticKVCache {
150
154
void update (
151
155
torch::executor::Method& method,
152
156
const std::vector<size_t >& output_indices,
153
- size_t update_len ,
157
+ size_t update_n ,
154
158
size_t update_pos = 0 ) {
155
- if (valid_len_ + update_len > cache_len_) {
156
- throw std::runtime_error (" Cache capacity exceeded." );
157
- }
158
-
159
159
for (size_t i = 0 ; i < n_caches_; i++) {
160
160
const auto & updateTensor =
161
161
method.get_output (output_indices[i]).toTensor ();
162
162
ET_CHECK (output_ptrs_[i] == updateTensor.mutable_data_ptr <T>());
163
- auto update_seq_len = updateTensor.size (updateTensor.dim () - 2 );
164
- for ( size_t j = 0 ; j < n_heads_per_cache_; j++) {
165
- auto * update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
166
- auto * cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
167
- std::copy (
168
- update_head + update_pos * head_dim_ ,
169
- update_head + (update_pos + update_len) * head_dim_ ,
170
- cache_head + valid_len_ * head_dim_);
171
- }
163
+ size_t update_len = updateTensor.size (updateTensor.dim () - 2 );
164
+ cache_pos_[i] = update_one_cache (
165
+ output_ptrs_[i],
166
+ update_len,
167
+ update_n,
168
+ update_pos,
169
+ input_ptrs_[i] ,
170
+ cache_len_[i],
171
+ cache_pos_[i]);
172
172
}
173
- valid_len_ += update_len;
174
173
}
175
174
176
175
/* *
177
- * Reset the cache. After this the cache contains no valid data and is ready
178
- * for number of tokens up to the cache length .
176
+ * Reset the cache. After this the cache contains no valid data and the mask
177
+ * should be updated to reflect this .
179
178
*/
180
179
void reset () {
181
- valid_len_ = 0 ;
182
- }
183
-
184
- size_t size () {
185
- return valid_len_;
180
+ std::fill (cache_pos_.begin (), cache_pos_.end (), 0 );
186
181
}
187
182
188
183
private:
189
184
void init_ptrs () {
190
185
input_ptrs_.resize (n_caches_);
191
186
output_ptrs_.resize (n_caches_);
187
+ size_t cache_data_offset = 0 ;
192
188
for (size_t i = 0 ; i < n_caches_; i++) {
193
- input_ptrs_[i] =
194
- cache_data_ + i * n_heads_per_cache_ * cache_len_ * head_dim_;
189
+ input_ptrs_[i] = cache_data_ + cache_data_offset;
190
+ cache_data_offset += cache_len_[i] * n_heads_per_cache_ * head_dim_;
195
191
output_ptrs_[i] =
196
192
update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
197
193
}
198
194
}
199
195
196
+ size_t update_one_cache (
197
+ const T* update,
198
+ size_t update_len,
199
+ size_t update_n,
200
+ size_t update_pos,
201
+ T* cache,
202
+ size_t cache_len,
203
+ size_t cache_pos) {
204
+ size_t wrap_n = 0 ;
205
+ auto contiguous_n = cache_len - cache_pos;
206
+ if (update_n > contiguous_n) {
207
+ wrap_n = update_n - contiguous_n;
208
+ update_n = contiguous_n;
209
+ }
210
+
211
+ // Update & cache shape: (1, n_heads, seq_len, head_dim)
212
+ for (size_t head = 0 ; head < n_heads_per_cache_; head++) {
213
+ auto * update_head = update + update_len * head_dim_ * head;
214
+ auto * cache_head = cache + cache_len * head_dim_ * head;
215
+ std::copy (
216
+ update_head + update_pos * head_dim_,
217
+ update_head + (update_pos + update_n) * head_dim_,
218
+ cache_head + cache_pos * head_dim_);
219
+ }
220
+ cache_pos += update_n;
221
+
222
+ if (wrap_n > 0 ) {
223
+ return update_one_cache (
224
+ update,
225
+ update_len,
226
+ wrap_n,
227
+ update_pos + contiguous_n,
228
+ cache,
229
+ cache_len,
230
+ 0 );
231
+ }
232
+
233
+ return cache_pos;
234
+ }
235
+
200
236
size_t n_caches_;
201
- size_t cache_len_;
237
+ std::vector<size_t > cache_len_;
238
+ std::vector<size_t > cache_pos_;
202
239
size_t max_input_len_;
203
240
size_t n_heads_per_cache_;
204
241
size_t head_dim_;
@@ -210,7 +247,6 @@ class StaticKVCache {
210
247
T* update_data_;
211
248
std::vector<T*> input_ptrs_;
212
249
std::vector<T*> output_ptrs_;
213
- size_t valid_len_ = 0 ;
214
250
};
215
251
216
252
template <typename T, typename AllocatorT = std::allocator<T>>
@@ -267,17 +303,20 @@ class StaticAttentionMask {
267
303
}
268
304
269
305
/* *
270
- * Update the mask to indicate update_len elements have been added to the
271
- * cache. Note that update_len might be smaller than input_len_ when
272
- * prefilling with padded inputs.
306
+ * Update the mask to indicate update_n elements have been added to the
307
+ * cache. Note that update_n might be smaller than input_len_ when prefilling
308
+ * with padded inputs.
273
309
*/
274
- void unmask (size_t update_len) {
275
- for (size_t i = 0 ; i < input_len_; i++) {
276
- auto * p = data_ + (cache_len_ + input_len_) * i;
277
- std::fill (
278
- p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
310
+ void unmask (size_t update_n) {
311
+ update_n = std::min (update_n, cache_len_ - cache_valid_len_);
312
+ if (update_n > 0 ) {
313
+ for (size_t i = 0 ; i < input_len_; i++) {
314
+ auto * p = data_ + (cache_len_ + input_len_) * i;
315
+ std::fill (
316
+ p + cache_valid_len_, p + cache_valid_len_ + update_n, zero_val_);
317
+ }
318
+ cache_valid_len_ += update_n;
279
319
}
280
- cache_valid_len_ += update_len;
281
320
}
282
321
283
322
void set_causal_mask () {
@@ -565,7 +604,7 @@ class StaticAttentionIOManager {
565
604
set_input (method, config_.attn_mask_input_index , mask.get ());
566
605
567
606
std::vector<TokenT> generated_tokens;
568
- while (kCaches_ . size () + 1 <= config_. cache_len ) {
607
+ while (true ) {
569
608
input_buffer[0 ] = prev_tok;
570
609
prepare (method);
571
610
ET_CHECK (method.execute () == executorch::runtime::Error::Ok);
@@ -629,7 +668,7 @@ class StaticAttentionIOManager {
629
668
std::max (window_size * (ngram_size - 1 ), static_cast <size_t >(1 ));
630
669
size_t n_inference = 0 ;
631
670
std::fill (input_buffer.begin (), input_buffer.end (), prev_tok);
632
- while (kCaches_ . size () + 1 <= config_. cache_len ) {
671
+ while (true ) {
633
672
input_buffer[0 ] = prev_tok;
634
673
// Initialize verification branches.
635
674
if (auto it = suffix_caches.find (prev_tok); it != suffix_caches.end ()) {
@@ -697,9 +736,6 @@ class StaticAttentionIOManager {
697
736
input_buffer[branch_offset + j] == match.back ();
698
737
j++) {
699
738
match.emplace_back (output_toks[branch_offset + j]);
700
- if (should_stop (match.back ())) {
701
- break ;
702
- }
703
739
}
704
740
if (match.size () > longest_match.size ()) {
705
741
longest_match = std::move (match);
@@ -724,8 +760,7 @@ class StaticAttentionIOManager {
724
760
method,
725
761
config_.k_cache_output_indices ,
726
762
config_.v_cache_output_indices ,
727
- std::min (
728
- longest_match.size () - 1 , config_.cache_len - kCaches_ .size ()),
763
+ longest_match.size () - 1 ,
729
764
branch_offset);
730
765
}
731
766
0 commit comments