 namespace example {
 
 enum class StaticAttentionUpdateStyle {
-  /**
-   * KV caches will have valid data at the end of the cache. New elements are
-   * added at the end and the start of the cache will slide forward to maintain
-   * this invariant. This potentially allows shorter caches to be passed into
-   * the model by adjusting the start pointer.
-   */
-  SLIDING_CACHE,
   /**
    * I/O pointers do not change which can enable persistent memory mapping
-   * between AP and NPU.
+   * between AP and NPU. However cache updates need to be copied.
    */
   SMART_MASK,
 };
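With SLIDING_CACHE removed, SMART_MASK becomes the only update style: the cache buffers keep fixed addresses (so the AP/NPU mapping can persist) and each step's new K/V rows are copied in after the currently valid region. A minimal standalone sketch of that append semantics, with buffer names and sizes invented for illustration:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative sketch of the SMART_MASK idea: the cache buffer address never
// changes (so it can stay memory-mapped); new rows are copied in after the
// valid region instead of sliding the window's start pointer.
int main() {
  const size_t cache_len = 8, head_dim = 4;
  std::vector<float> cache(cache_len * head_dim, 0.0f); // fixed allocation
  size_t valid_len = 0;

  std::vector<float> update(2 * head_dim, 1.0f); // two new K/V rows
  std::copy(
      update.begin(),
      update.end(),
      cache.begin() + valid_len * head_dim); // append after valid data
  valid_len += 2;
  return 0;
}
```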
@@ -49,35 +42,17 @@ class StaticKVCache {
       size_t head_dim,
       size_t max_input_len = 1,
       size_t n_heads_per_cache = 1,
-      bool transpose = false,
       StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+          StaticAttentionUpdateStyle::SMART_MASK)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
-        transpose_(transpose),
         style_(style),
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
-        n_heads_per_cache_ > 1) {
-      throw std::runtime_error(
-          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      // Allocates on extra copy to accomodate caches sliding forward.
-      cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
-    } else {
-      cache_data_size_ =
-          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
-    }
+    cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
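Dropping the sliding style also drops the extra slot previously allocated for the sliding copy, so the cache buffer is exactly `n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_` elements. A quick worked check of the two sizes, with hypothetical shapes:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical sizes for illustration only.
  const size_t n_caches = 32;     // e.g. one K and one V cache per layer
  const size_t n_heads_per_cache = 1;
  const size_t cache_len = 1024;
  const size_t head_dim = 128;
  const size_t max_input_len = 1; // decode one token at a time

  const size_t cache_elems =
      n_caches * n_heads_per_cache * cache_len * head_dim;
  const size_t update_elems =
      n_caches * n_heads_per_cache * max_input_len * head_dim;
  std::printf("cache: %zu elems, update: %zu elems\n",
      cache_elems, update_elems);
  // Prints: cache: 4194304 elems, update: 4096 elems
  return 0;
}
```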
@@ -145,17 +120,10 @@ class StaticKVCache {
           "Cache input tensor expected to have rank 3.");
     }
     auto ndim = inSizes.size();
-    if (transpose_) {
-      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
-    } else {
-      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
-    }
+    ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(
+        inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
 
     auto impl = ::executorch::runtime::etensor::TensorImpl(
         inMeta->scalar_type(),
@@ -189,11 +157,21 @@ class StaticKVCache {
       throw std::runtime_error("Cache capacity exceeded.");
     }
 
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      update_sliding_cache(method, output_indices, update_len, update_pos);
-    } else {
-      update_smart_mask(method, output_indices, update_len, update_pos);
+    for (size_t i = 0; i < n_caches_; i++) {
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
+    valid_len_ += update_len;
   }
 
   /**
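The update logic formerly in update_smart_mask() is now inlined above. It treats each cache as [n_heads, cache_len, head_dim] and each model output as [n_heads, seq_len, head_dim], copying update_len rows per head to the first invalid cache row. The same pointer arithmetic as a self-contained function (names and the free-function form are illustrative, not from this codebase):

```cpp
#include <algorithm>
#include <cstddef>

// Copies `update_len` rows from each head of a [n_heads, update_seq_len,
// head_dim] update buffer into the same head of a [n_heads, cache_len,
// head_dim] cache, starting at the first invalid cache row. Mirrors the
// inlined loop in the diff.
void append_update(
    const float* update,
    float* cache,
    size_t n_heads,
    size_t update_seq_len,
    size_t cache_len,
    size_t head_dim,
    size_t update_pos,
    size_t update_len,
    size_t valid_len) {
  for (size_t j = 0; j < n_heads; j++) {
    const float* update_head = update + update_seq_len * head_dim * j;
    float* cache_head = cache + cache_len * head_dim * j;
    std::copy(
        update_head + update_pos * head_dim,
        update_head + (update_pos + update_len) * head_dim,
        cache_head + valid_len * head_dim);
  }
}
```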
@@ -202,9 +180,6 @@ class StaticKVCache {
    */
   void reset() {
     valid_len_ = 0;
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      init_ptrs();
-    }
   }
 
   size_t size() {
@@ -223,53 +198,11 @@ class StaticKVCache {
     }
   }
 
-  void update_sliding_cache(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    ET_CHECK(n_caches_ == output_indices.size());
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + cache_len_ * head_dim_);
-      input_ptrs_[i] += update_len * head_dim_;
-    }
-    valid_len_ += update_len;
-  }
-
-  void update_smart_mask(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
-      for (size_t j = 0; j < n_heads_per_cache_; j++) {
-        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
-        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
-        std::copy(
-            update_head + update_pos * head_dim_,
-            update_head + (update_pos + update_len) * head_dim_,
-            cache_head + valid_len_ * head_dim_);
-      }
-    }
-    valid_len_ += update_len;
-  }
-
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
   size_t head_dim_;
-  bool transpose_;
   StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
   size_t cache_data_size_;
@@ -301,7 +234,7 @@ class StaticAttentionMask {
       T zero_val,
       T mask_val,
       StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+          StaticAttentionUpdateStyle::SMART_MASK)
       : cache_len_(cache_len),
         input_len_(input_len),
         head_dim_(head_dim),
@@ -341,20 +274,10 @@ class StaticAttentionMask {
    * prefilling with padded inputs.
    */
   void unmask(size_t update_len) {
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_len_ - cache_valid_len_ - update_len,
-            p + cache_len_ - cache_valid_len_,
-            zero_val_);
-      }
-    } else {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
-      }
+    for (size_t i = 0; i < input_len_; i++) {
+      auto* p = data_ + (cache_len_ + input_len_) * i;
+      std::fill(
+          p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
     }
     cache_valid_len_ += update_len;
   }
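unmask() likewise keeps only the SMART_MASK path: each of the input_len_ mask rows is cache_len_ + input_len_ wide with the cache columns first, and newly written cache positions are unmasked left to right starting at cache_valid_len_. A toy, self-contained version of that fill (sizes and mask values are invented for illustration):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

int main() {
  // Toy mask: input_len rows, each row = cache columns then input columns.
  const size_t cache_len = 6, input_len = 2;
  const float mask_val = -1e9f, zero_val = 0.0f;
  std::vector<float> mask(input_len * (cache_len + input_len), mask_val);

  size_t cache_valid_len = 3; // three cache positions already valid
  size_t update_len = 2;      // two more just written by update()
  for (size_t i = 0; i < input_len; i++) {
    float* p = mask.data() + (cache_len + input_len) * i;
    // Unmask columns [cache_valid_len, cache_valid_len + update_len).
    std::fill(p + cache_valid_len, p + cache_valid_len + update_len, zero_val);
  }
  cache_valid_len += update_len;
  return 0;
}
```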
@@ -469,7 +392,7 @@ class StaticAttentionIOManager {
   RopeT* rope_freqs_cos;
   RopeT* rope_freqs_sin;
   StaticAttentionUpdateStyle style =
-      StaticAttentionUpdateStyle::SLIDING_CACHE;
+      StaticAttentionUpdateStyle::SMART_MASK;
 };
 
 StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -480,15 +403,13 @@ class StaticAttentionIOManager {
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style),
         vCaches_(
             config_.n_caches,
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style) {
     ET_LOG(
         Info,
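With the transpose argument gone and SMART_MASK as the default, configuring the manager reduces to the shape parameters. A hypothetical setup using a stand-in for the config struct (field values and the float RoPE type are invented for illustration; only fields visible in this diff are mirrored):

```cpp
#include <cstddef>

// Minimal stand-in mirroring the fields visible in this diff; the real
// struct lives in the ExecuTorch example sources and is templated.
enum class StaticAttentionUpdateStyle { SMART_MASK };

struct StaticAttentionIOConfig {
  size_t n_caches = 0;
  size_t cache_len = 0;
  size_t head_dim = 0;
  size_t max_input_len = 0;
  size_t n_heads_per_cache = 0;
  float* rope_freqs_cos = nullptr;
  float* rope_freqs_sin = nullptr;
  StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
};

int main() {
  StaticAttentionIOConfig config; // illustrative values only
  config.n_caches = 2 * 16;       // K and V caches for 16 layers
  config.cache_len = 1024;
  config.head_dim = 64;
  config.max_input_len = 32;
  config.n_heads_per_cache = 8;   // multi-head cache tensors now work
  // config.style already defaults to SMART_MASK after this change.
  return 0;
}
```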