 namespace example {
 
 enum class StaticAttentionUpdateStyle {
-  /**
-   * KV caches will have valid data at the end of the cache. New elements are
-   * added at the end and the start of the cache will slide forward to maintain
-   * this invariant. This potentially allows shorter caches to be passed into
-   * the model by adjusting the start pointer.
-   */
-  SLIDING_CACHE,
   /**
    * I/O pointers do not change which can enable persistent memory mapping
-   * between AP and NPU.
+   * between AP and NPU. However cache updates need to be copied.
    */
   SMART_MASK,
 };
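For orientation (not part of the diff): a minimal standalone sketch, using made-up toy sizes, of what the SMART_MASK style implies for the runtime. The cache buffer keeps a fixed address across steps, so a backend can keep it persistently mapped, and each step's new rows are copied in at the current valid length.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical toy dimensions, not taken from the diff.
constexpr size_t kCacheLen = 8;
constexpr size_t kHeadDim = 4;

int main() {
  // The cache buffer lives at a fixed address for the whole session, so the
  // backend can keep it persistently mapped; only its contents change.
  std::vector<float> cache(kCacheLen * kHeadDim, 0.0f);
  size_t valid_len = 0;

  // Pretend the model produced 2 new rows of K (or V) this step.
  size_t update_len = 2;
  std::vector<float> update(update_len * kHeadDim, 1.0f);

  // SMART_MASK-style update: copy the new rows in place at the current
  // valid length instead of sliding the cache start pointer forward.
  std::copy(
      update.begin(),
      update.begin() + update_len * kHeadDim,
      cache.begin() + valid_len * kHeadDim);
  valid_len += update_len;
  return 0;
}
```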
@@ -49,35 +42,16 @@ class StaticKVCache {
       size_t head_dim,
       size_t max_input_len = 1,
       size_t n_heads_per_cache = 1,
-      bool transpose = false,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+      StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
-        transpose_(transpose),
         style_(style),
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
-        n_heads_per_cache_ > 1) {
-      throw std::runtime_error(
-          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      // Allocates on extra copy to accomodate caches sliding forward.
-      cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
-    } else {
-      cache_data_size_ =
-          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
-    }
+    cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
@@ -145,17 +119,10 @@ class StaticKVCache {
         "Cache input tensor expected to have rank 3.");
     }
     auto ndim = inSizes.size();
-    if (transpose_) {
-      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
-    } else {
-      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
-    }
+    ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(
+        inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
 
     auto impl = ::executorch::runtime::etensor::TensorImpl(
         inMeta->scalar_type(),
@@ -189,11 +156,21 @@ class StaticKVCache {
       throw std::runtime_error("Cache capacity exceeded.");
     }
 
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      update_sliding_cache(method, output_indices, update_len, update_pos);
-    } else {
-      update_smart_mask(method, output_indices, update_len, update_pos);
+    for (size_t i = 0; i < n_caches_; i++) {
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
+    valid_len_ += update_len;
   }
 
   /**
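Illustration only, not from the PR: the same per-head indexing as the inlined loop above, replayed on plain vectors with assumed toy shapes, showing how each head's slab in the update tensor is committed into its slab in the cache tensor.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical shapes; the real values come from the model's cache tensors.
constexpr size_t kHeads = 2;         // n_heads_per_cache_
constexpr size_t kCacheLen = 8;      // cache_len_
constexpr size_t kUpdateSeqLen = 4;  // rows produced per forward pass
constexpr size_t kHeadDim = 4;       // head_dim_

int main() {
  // One cache tensor holds kHeads contiguous slabs of kCacheLen x kHeadDim.
  std::vector<float> cache(kHeads * kCacheLen * kHeadDim, 0.0f);
  // The matching update tensor holds kHeads slabs of kUpdateSeqLen x kHeadDim.
  std::vector<float> update(kHeads * kUpdateSeqLen * kHeadDim, 1.0f);

  size_t valid_len = 0;   // rows already committed to the cache
  size_t update_pos = 0;  // first new row within the update tensor
  size_t update_len = 2;  // rows to commit this step

  // Same indexing as the loop in the diff: each head's new rows are copied
  // from its slab in the update tensor into its slab in the cache tensor.
  for (size_t j = 0; j < kHeads; j++) {
    auto* update_head = update.data() + kUpdateSeqLen * kHeadDim * j;
    auto* cache_head = cache.data() + kCacheLen * kHeadDim * j;
    std::copy(
        update_head + update_pos * kHeadDim,
        update_head + (update_pos + update_len) * kHeadDim,
        cache_head + valid_len * kHeadDim);
  }
  valid_len += update_len;
  return 0;
}
```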
@@ -202,9 +179,6 @@ class StaticKVCache {
    */
   void reset() {
     valid_len_ = 0;
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      init_ptrs();
-    }
   }
 
   size_t size() {
@@ -223,53 +197,11 @@ class StaticKVCache {
     }
   }
 
-  void update_sliding_cache(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    ET_CHECK(n_caches_ == output_indices.size());
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + cache_len_ * head_dim_);
-      input_ptrs_[i] += update_len * head_dim_;
-    }
-    valid_len_ += update_len;
-  }
-
-  void update_smart_mask(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
-      for (size_t j = 0; j < n_heads_per_cache_; j++) {
-        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
-        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
-        std::copy(
-            update_head + update_pos * head_dim_,
-            update_head + (update_pos + update_len) * head_dim_,
-            cache_head + valid_len_ * head_dim_);
-      }
-    }
-    valid_len_ += update_len;
-  }
-
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
   size_t head_dim_;
-  bool transpose_;
   StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
   size_t cache_data_size_;
@@ -300,8 +232,7 @@ class StaticAttentionMask {
       size_t head_dim,
       T zero_val,
       T mask_val,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+      StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
       : cache_len_(cache_len),
         input_len_(input_len),
         head_dim_(head_dim),
@@ -341,20 +272,10 @@ class StaticAttentionMask {
    * prefilling with padded inputs.
    */
   void unmask(size_t update_len) {
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_len_ - cache_valid_len_ - update_len,
-            p + cache_len_ - cache_valid_len_,
-            zero_val_);
-      }
-    } else {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
-      }
+    for (size_t i = 0; i < input_len_; i++) {
+      auto* p = data_ + (cache_len_ + input_len_) * i;
+      std::fill(
+          p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
     }
     cache_valid_len_ += update_len;
   }
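Also illustration only: a toy version of the SMART_MASK unmask path kept by this change, with assumed mask values and made-up cache_len/input_len, showing that visibility grows from the front of each row's cache region as the cache fills.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical sizes; the real mask is built from cache_len_ and input_len_.
constexpr size_t kCacheLen = 6;
constexpr size_t kInputLen = 3;

int main() {
  const float mask_val = -65504.0f;  // "masked" value, assumption
  const float zero_val = 0.0f;       // "visible" value

  // One mask row per input position; each row covers the cache followed by
  // the current input window, so its width is kCacheLen + kInputLen.
  std::vector<float> mask(kInputLen * (kCacheLen + kInputLen), mask_val);

  size_t cache_valid_len = 0;

  // SMART_MASK unmasking: newly written cache rows become visible starting
  // at cache_valid_len in every row, mirroring the loop kept in the diff.
  auto unmask = [&](size_t update_len) {
    for (size_t i = 0; i < kInputLen; i++) {
      auto* p = mask.data() + (kCacheLen + kInputLen) * i;
      std::fill(p + cache_valid_len, p + cache_valid_len + update_len, zero_val);
    }
    cache_valid_len += update_len;
  };

  unmask(2);  // two cache positions become attendable
  unmask(1);  // one more
  return 0;
}
```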
@@ -468,8 +389,7 @@ class StaticAttentionIOManager {
     std::vector<size_t> v_cache_output_indices;
     RopeT* rope_freqs_cos;
     RopeT* rope_freqs_sin;
-    StaticAttentionUpdateStyle style =
-        StaticAttentionUpdateStyle::SLIDING_CACHE;
+    StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
   };
 
   StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -480,15 +400,13 @@ class StaticAttentionIOManager {
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style),
         vCaches_(
             config_.n_caches,
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style) {
     ET_LOG(
         Info,