 
 #include <algorithm>
 #include <memory>
+#include <numeric>
 #include <tuple>
 #include <unordered_map>
 #include <vector>
@@ -44,14 +45,17 @@ class StaticKVCache {
       size_t n_heads_per_cache = 1,
       StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
       : n_caches_(n_caches),
-        cache_len_(cache_len),
+        cache_len_(n_caches_, cache_len),
+        cache_pos_(n_caches_, 0),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
         style_(style),
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
-    cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
+    size_t total_cache_len =
+        std::accumulate(cache_len_.begin(), cache_len_.end(), 0);
+    cache_data_size_ = total_cache_len * n_heads_per_cache_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
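
With `cache_len_` now a per-cache vector, the backing buffer is sized as the sum of the individual cache lengths instead of `n_caches_ * cache_len_`. Below is a minimal standalone sketch of that sizing arithmetic; the function name and signature are illustrative, not part of the patch.

```cpp
// Sketch of the new sizing arithmetic; names are illustrative, not from the
// patch.
#include <cstddef>
#include <numeric>
#include <vector>

size_t total_cache_elems(
    const std::vector<size_t>& cache_len, // one length per KV cache
    size_t n_heads_per_cache,
    size_t head_dim) {
  // Sum of per-cache lengths replaces n_caches * cache_len.
  size_t total_cache_len =
      std::accumulate(cache_len.begin(), cache_len.end(), size_t{0});
  return total_cache_len * n_heads_per_cache * head_dim;
}
```

One observation: `std::accumulate(..., 0)` in the patch accumulates into an `int`, since the literal `0` fixes the result type; the sketch seeds with `size_t{0}`, which keeps the sum in `size_t` for very large caches.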
@@ -122,7 +126,7 @@ class StaticKVCache {
       ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
       ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
       ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
+          inSizes[ndim - 2] == cache_len_[i], "Cache length dim mismatch.");
 
       auto impl = ::executorch::runtime::etensor::TensorImpl(
           inMeta->scalar_type(),
@@ -150,55 +154,88 @@ class StaticKVCache {
   void update(
       torch::executor::Method& method,
       const std::vector<size_t>& output_indices,
-      size_t update_len,
+      size_t update_n,
       size_t update_pos = 0) {
-    if (valid_len_ + update_len > cache_len_) {
-      throw std::runtime_error("Cache capacity exceeded.");
-    }
-
     for (size_t i = 0; i < n_caches_; i++) {
       const auto& updateTensor =
           method.get_output(output_indices[i]).toTensor();
       ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
-      for (size_t j = 0; j < n_heads_per_cache_; j++) {
-        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
-        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
-        std::copy(
-            update_head + update_pos * head_dim_,
-            update_head + (update_pos + update_len) * head_dim_,
-            cache_head + valid_len_ * head_dim_);
-      }
+      size_t update_len = updateTensor.size(updateTensor.dim() - 2);
+      cache_pos_[i] = update_one_cache(
+          output_ptrs_[i],
+          update_len,
+          update_n,
+          update_pos,
+          input_ptrs_[i],
+          cache_len_[i],
+          cache_pos_[i]);
     }
-    valid_len_ += update_len;
   }
 
   /**
-   * Reset the cache. After this the cache contains no valid data and is ready
-   * for number of tokens up to the cache length.
+   * Reset the cache. After this the cache contains no valid data and the mask
+   * should be updated to reflect this.
    */
   void reset() {
-    valid_len_ = 0;
-  }
-
-  size_t size() {
-    return valid_len_;
+    std::fill(cache_pos_.begin(), cache_pos_.end(), 0);
   }
 
  private:
   void init_ptrs() {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
+    size_t cache_data_offset = 0;
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] =
-          cache_data_ + i * n_heads_per_cache_ * cache_len_ * head_dim_;
+      input_ptrs_[i] = cache_data_ + cache_data_offset;
+      cache_data_offset += cache_len_[i] * n_heads_per_cache_ * head_dim_;
       output_ptrs_[i] =
           update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
     }
   }
 
+  size_t update_one_cache(
+      const T* update,
+      size_t update_len,
+      size_t update_n,
+      size_t update_pos,
+      T* cache,
+      size_t cache_len,
+      size_t cache_pos) {
+    size_t wrap_n = 0;
+    auto contiguous_n = cache_len - cache_pos;
+    if (update_n > contiguous_n) {
+      wrap_n = update_n - contiguous_n;
+      update_n = contiguous_n;
+    }
+
+    // Update & cache shape: (1, n_heads, seq_len, head_dim)
+    for (size_t head = 0; head < n_heads_per_cache_; head++) {
+      auto* update_head = update + update_len * head_dim_ * head;
+      auto* cache_head = cache + cache_len * head_dim_ * head;
+      std::copy(
+          update_head + update_pos * head_dim_,
+          update_head + (update_pos + update_n) * head_dim_,
+          cache_head + cache_pos * head_dim_);
+    }
+    cache_pos += update_n;
+
+    if (wrap_n > 0) {
+      return update_one_cache(
+          update,
+          update_len,
+          wrap_n,
+          update_pos + contiguous_n,
+          cache,
+          cache_len,
+          0);
+    }
+
+    return cache_pos;
+  }
+
   size_t n_caches_;
-  size_t cache_len_;
+  std::vector<size_t> cache_len_;
+  std::vector<size_t> cache_pos_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
   size_t head_dim_;
@@ -210,7 +247,6 @@ class StaticKVCache {
   T* update_data_;
   std::vector<T*> input_ptrs_;
   std::vector<T*> output_ptrs_;
-  size_t valid_len_ = 0;
 };
 
 template <typename T, typename AllocatorT = std::allocator<T>>
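
The heart of the change is `update_one_cache`, which treats each cache as a ring buffer: a write that would run past the end is split into a contiguous part and a wrapped remainder, handled by the tail-recursive call with `cache_pos` reset to 0. The recursion goes at most one level deep as long as a single update fits in the cache. A minimal sketch of the same position arithmetic, assuming one head and `head_dim == 1` so each sequence position is a single element:

```cpp
// Minimal ring-buffer write sketch (assumes n_heads == 1 and head_dim == 1).
// Mirrors update_one_cache's split-then-wrap logic without the recursion.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

size_t ring_write(
    const std::vector<int>& src, // new entries to append
    std::vector<int>& cache, // fixed-capacity ring buffer
    size_t pos) { // current write position
  size_t contiguous_n = cache.size() - pos;
  size_t first = std::min(src.size(), contiguous_n);
  std::copy(src.begin(), src.begin() + first, cache.begin() + pos);
  if (src.size() > first) {
    // Wrapped remainder starts over at the front of the cache.
    std::copy(src.begin() + first, src.end(), cache.begin());
    return src.size() - first;
  }
  return (pos + first) % cache.size();
}

int main() {
  std::vector<int> cache(4, 0);
  size_t pos = ring_write({1, 2, 3}, cache, 0); // cache: 1 2 3 0, pos -> 3
  pos = ring_write({4, 5}, cache, pos); // wraps:  5 2 3 4, pos -> 1
  std::printf("pos = %zu\n", pos);
  return 0;
}
```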
@@ -267,17 +303,20 @@ class StaticAttentionMask {
   }
 
   /**
-   * Update the mask to indicate update_len elements have been added to the
-   * cache. Note that update_len might be smaller than input_len_ when
-   * prefilling with padded inputs.
+   * Update the mask to indicate update_n elements have been added to the
+   * cache. Note that update_n might be smaller than input_len_ when prefilling
+   * with padded inputs.
    */
-  void unmask(size_t update_len) {
-    for (size_t i = 0; i < input_len_; i++) {
-      auto* p = data_ + (cache_len_ + input_len_) * i;
-      std::fill(
-          p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
+  void unmask(size_t update_n) {
+    update_n = std::min(update_n, cache_len_ - cache_valid_len_);
+    if (update_n > 0) {
+      for (size_t i = 0; i < input_len_; i++) {
+        auto* p = data_ + (cache_len_ + input_len_) * i;
+        std::fill(
+            p + cache_valid_len_, p + cache_valid_len_ + update_n, zero_val_);
+      }
+      cache_valid_len_ += update_n;
     }
-    cache_valid_len_ += update_len;
   }
 
   void set_causal_mask() {
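
`unmask` now clamps `update_n` so `cache_valid_len_` can never exceed `cache_len_`: once every cache position is unmasked, further calls are no-ops, and the ring buffer silently overwrites the oldest entries, which remain visible to attention. A freestanding sketch, assuming the same row-major mask layout of `input_len` rows by `cache_len + input_len` columns (names are illustrative):

```cpp
// Sketch of the clamped unmask over an assumed row-major
// input_len x (cache_len + input_len) mask buffer.
#include <algorithm>
#include <cstddef>
#include <vector>

void unmask_sketch(
    std::vector<float>& mask,
    size_t input_len,
    size_t cache_len,
    size_t& cache_valid_len, // positions already unmasked
    size_t update_n,
    float zero_val = 0.0f) {
  // Never unmask past the physical cache; extra requests are no-ops.
  update_n = std::min(update_n, cache_len - cache_valid_len);
  if (update_n == 0) {
    return;
  }
  for (size_t i = 0; i < input_len; i++) {
    float* row = mask.data() + (cache_len + input_len) * i;
    std::fill(row + cache_valid_len, row + cache_valid_len + update_n, zero_val);
  }
  cache_valid_len += update_n;
}
```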
@@ -565,7 +604,7 @@ class StaticAttentionIOManager {
     set_input(method, config_.attn_mask_input_index, mask.get());
 
     std::vector<TokenT> generated_tokens;
-    while (kCaches_.size() + 1 <= config_.cache_len) {
+    while (true) {
       input_buffer[0] = prev_tok;
       prepare(method);
       ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
@@ -629,7 +668,7 @@ class StaticAttentionIOManager {
         std::max(window_size * (ngram_size - 1), static_cast<size_t>(1));
     size_t n_inference = 0;
     std::fill(input_buffer.begin(), input_buffer.end(), prev_tok);
-    while (kCaches_.size() + 1 <= config_.cache_len) {
+    while (true) {
       input_buffer[0] = prev_tok;
       // Initialize verification branches.
       if (auto it = suffix_caches.find(prev_tok); it != suffix_caches.end()) {
@@ -697,9 +736,6 @@ class StaticAttentionIOManager {
                input_buffer[branch_offset + j] == match.back();
                j++) {
             match.emplace_back(output_toks[branch_offset + j]);
-            if (should_stop(match.back())) {
-              break;
-            }
           }
           if (match.size() > longest_match.size()) {
             longest_match = std::move(match);
@@ -724,8 +760,7 @@ class StaticAttentionIOManager {
           method,
           config_.k_cache_output_indices,
           config_.v_cache_output_indices,
-          std::min(
-              longest_match.size() - 1, config_.cache_len - kCaches_.size()),
+          longest_match.size() - 1,
           branch_offset);
       }
 
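
Both decode loops now spin on `while (true)`: with a ring-buffer cache there is no capacity to exhaust, so the old `kCaches_.size() + 1 <= config_.cache_len` bound is gone and termination relies on stop tokens; the lookahead path can likewise commit all `longest_match.size() - 1` verified tokens without clamping against remaining capacity. A hypothetical caller-side guard, sketched with stand-in helpers, for callers that still want a hard generation budget:

```cpp
// Hypothetical caller-side budget (not part of the patch); decode_one_step
// and should_stop are stand-ins for the manager's decode step and the
// caller-provided stop check.
#include <cstddef>
#include <vector>

int decode_one_step(); // assumed: runs one forward pass, returns next token
bool should_stop(int tok); // assumed: stop-token test

std::vector<int> generate_with_budget(size_t max_new_tokens) {
  std::vector<int> out;
  // The patched loops run unbounded; a hard cap keeps decoding finite even
  // if the model never emits a stop token.
  while (out.size() < max_new_tokens) {
    int tok = decode_one_step();
    out.push_back(tok);
    if (should_stop(tok)) {
      break;
    }
  }
  return out;
}
```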