 namespace example {
 
 enum class StaticAttentionUpdateStyle {
-  /**
-   * KV caches will have valid data at the end of the cache. New elements are
-   * added at the end and the start of the cache will slide forward to maintain
-   * this invariant. This potentially allows shorter caches to be passed into
-   * the model by adjusting the start pointer.
-   */
-  SLIDING_CACHE,
   /**
    * I/O pointers do not change which can enable persistent memory mapping
-   * between AP and NPU.
+   * between AP and NPU. However, cache updates need to be copied.
    */
   SMART_MASK,
 };
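
Note: with SLIDING_CACHE gone, SMART_MASK is the only update style. Its contract is that the cache I/O pointers never change between inferences (which is what allows persistent memory mapping between AP and NPU), and each step's new K/V rows are instead copied into the cache buffer. A minimal standalone sketch of that idea, using illustrative names rather than the real StaticKVCache API:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Toy stand-in for a SMART_MASK-style cache: the storage address never
    // changes; committing an update means copying rows after the valid region.
    // Caller must ensure valid_len + update_len <= cache_len.
    struct ToyCache {
      size_t cache_len;
      size_t head_dim;
      size_t valid_len = 0;
      std::vector<float> data; // fixed buffer of cache_len * head_dim elements

      ToyCache(size_t cache_len, size_t head_dim)
          : cache_len(cache_len),
            head_dim(head_dim),
            data(cache_len * head_dim) {}

      // Copy update_len newly produced rows into the cache; data.data() stays
      // stable, only valid_len advances.
      void append(const float* update, size_t update_len) {
        std::copy(
            update,
            update + update_len * head_dim,
            data.data() + valid_len * head_dim);
        valid_len += update_len;
      }
    };
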
@@ -49,35 +42,16 @@ class StaticKVCache {
       size_t head_dim,
       size_t max_input_len = 1,
       size_t n_heads_per_cache = 1,
-      bool transpose = false,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+      StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
-        transpose_(transpose),
         style_(style),
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
-        n_heads_per_cache_ > 1) {
-      throw std::runtime_error(
-          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      // Allocates an extra copy to accommodate caches sliding forward.
-      cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
-    } else {
-      cache_data_size_ =
-          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
-    }
+    cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
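Dropping the sliding style also removes the extra cache copy it needed, so the backing allocation is simply n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_ elements. A worked sizing example with assumed (illustrative) dimensions:

    #include <cstddef>

    // Illustrative sizing only; the parameter values below are assumptions,
    // not taken from this diff.
    constexpr size_t n_caches = 32;          // e.g. one K cache per layer
    constexpr size_t n_heads_per_cache = 1;
    constexpr size_t cache_len = 1024;
    constexpr size_t head_dim = 64;
    constexpr size_t cache_elems =
        n_caches * n_heads_per_cache * cache_len * head_dim; // 2,097,152
    constexpr size_t cache_bytes_fp16 = cache_elems * 2;     // 4 MiB
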
@@ -145,17 +119,10 @@ class StaticKVCache {
         "Cache input tensor expected to have rank 3.");
     }
     auto ndim = inSizes.size();
-    if (transpose_) {
-      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
-    } else {
-      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
-    }
+    ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(
+        inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
 
     auto impl = ::executorch::runtime::etensor::TensorImpl(
         inMeta->scalar_type(),
@@ -189,11 +156,21 @@ class StaticKVCache {
       throw std::runtime_error("Cache capacity exceeded.");
     }
 
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      update_sliding_cache(method, output_indices, update_len, update_pos);
-    } else {
-      update_smart_mask(method, output_indices, update_len, update_pos);
+    for (size_t i = 0; i < n_caches_; i++) {
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
+    valid_len_ += update_len;
   }
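
The smart-mask update now lives directly in update(): for every cache tensor, each head's newly produced rows are copied from the model output buffer into the cache buffer at offset valid_len_, and valid_len_ advances once per call. The pointer arithmetic is easier to read in a standalone sketch that uses raw pointers instead of the ExecuTorch tensors (per-head layout assumed row-major [seq, head_dim], as the strides above imply):

    #include <algorithm>
    #include <cstddef>

    // update: n_heads blocks of update_seq_len x head_dim rows
    // cache:  n_heads blocks of cache_len x head_dim rows
    void copy_update_into_cache(
        const float* update,
        float* cache,
        size_t n_heads,
        size_t update_seq_len, // allocated rows in the update buffer
        size_t cache_len,
        size_t head_dim,
        size_t update_pos,     // first row of the update to commit
        size_t update_len,     // number of rows to commit
        size_t valid_len) {    // rows already in the cache
      for (size_t j = 0; j < n_heads; j++) {
        const float* update_head = update + j * update_seq_len * head_dim;
        float* cache_head = cache + j * cache_len * head_dim;
        std::copy(
            update_head + update_pos * head_dim,
            update_head + (update_pos + update_len) * head_dim,
            cache_head + valid_len * head_dim);
      }
    }
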
 
   /**
@@ -202,9 +179,6 @@ class StaticKVCache {
    */
   void reset() {
     valid_len_ = 0;
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      init_ptrs();
-    }
   }
 
   size_t size() {
@@ -223,53 +197,11 @@ class StaticKVCache {
     }
   }
 
-  void update_sliding_cache(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    ET_CHECK(n_caches_ == output_indices.size());
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + cache_len_ * head_dim_);
-      input_ptrs_[i] += update_len * head_dim_;
-    }
-    valid_len_ += update_len;
-  }
-
-  void update_smart_mask(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
-      for (size_t j = 0; j < n_heads_per_cache_; j++) {
-        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
-        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
-        std::copy(
-            update_head + update_pos * head_dim_,
-            update_head + (update_pos + update_len) * head_dim_,
-            cache_head + valid_len_ * head_dim_);
-      }
-    }
-    valid_len_ += update_len;
-  }
-
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
   size_t head_dim_;
-  bool transpose_;
   StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
   size_t cache_data_size_;
@@ -300,8 +232,7 @@ class StaticAttentionMask {
       size_t head_dim,
       T zero_val,
       T mask_val,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+      StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
       : cache_len_(cache_len),
         input_len_(input_len),
         head_dim_(head_dim),
@@ -341,20 +272,10 @@ class StaticAttentionMask {
    * prefilling with padded inputs.
    */
   void unmask(size_t update_len) {
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_len_ - cache_valid_len_ - update_len,
-            p + cache_len_ - cache_valid_len_,
-            zero_val_);
-      }
-    } else {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
-      }
+    for (size_t i = 0; i < input_len_; i++) {
+      auto* p = data_ + (cache_len_ + input_len_) * i;
+      std::fill(
+          p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
     }
     cache_valid_len_ += update_len;
   }
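
With only SMART_MASK left, unmask() always exposes newly cached positions at the front of each mask row: the mask is input_len rows of width cache_len + input_len, and entries [cache_valid_len_, cache_valid_len_ + update_len) of every row are set to zero_val_. A hedged standalone sketch of the same fill, with raw pointers standing in for the class members:

    #include <algorithm>
    #include <cstddef>

    // Each of the input_len mask rows covers cache_len cached positions
    // followed by input_len current positions; zero_val means "attend".
    void unmask_rows(
        float* data,
        size_t input_len,
        size_t cache_len,
        size_t cache_valid_len, // positions already exposed
        size_t update_len,      // positions newly written to the cache
        float zero_val) {
      for (size_t i = 0; i < input_len; i++) {
        float* row = data + (cache_len + input_len) * i;
        std::fill(
            row + cache_valid_len, row + cache_valid_len + update_len, zero_val);
      }
    }
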
@@ -468,8 +389,7 @@ class StaticAttentionIOManager {
     std::vector<size_t> v_cache_output_indices;
     RopeT* rope_freqs_cos;
     RopeT* rope_freqs_sin;
-    StaticAttentionUpdateStyle style =
-        StaticAttentionUpdateStyle::SLIDING_CACHE;
+    StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
   };
 
   StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -480,15 +400,13 @@ class StaticAttentionIOManager {
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style),
         vCaches_(
             config_.n_caches,
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style) {
     ET_LOG(
         Info,