@@ -48,12 +48,14 @@ class StaticKVCache {
       size_t cache_len,
       size_t head_dim,
       size_t max_input_len = 1,
+      size_t n_heads_per_cache = 1,
       bool transpose = false,
       StaticAttentionUpdateStyle style =
           StaticAttentionUpdateStyle::SLIDING_CACHE)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
+        n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
         transpose_(transpose),
         style_(style),
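
The new `n_heads_per_cache` argument controls how many attention heads share a single cache tensor. As a rough standalone sketch of how the constructor arguments might relate to model hyper-parameters (the layer and head counts are made up, and the per-layer-per-head mapping for the split-SHA case is inferred from the comments later in this diff):

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative only: how n_caches and n_heads_per_cache might be derived for
// the two layouts this change supports. Numbers are hypothetical.
int main() {
  const std::size_t n_layers = 4, n_kv_heads = 8;

  // Split-SHA export (e.g. for QNN): one rank-3 cache tensor per layer per head.
  std::size_t sha_n_caches = n_layers * n_kv_heads;  // 32 caches
  std::size_t sha_heads_per_cache = 1;

  // Regular MHA export: one rank-4 cache tensor per layer, heads packed together.
  std::size_t mha_n_caches = n_layers;               // 4 caches
  std::size_t mha_heads_per_cache = n_kv_heads;      // 8 heads each

  std::printf("SHA: %zu caches x %zu head(s)\n", sha_n_caches, sha_heads_per_cache);
  std::printf("MHA: %zu caches x %zu head(s)\n", mha_n_caches, mha_heads_per_cache);
  return 0;
}
```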
@@ -63,13 +65,21 @@ class StaticKVCache {
       throw std::runtime_error("Not implemented.");
     }

+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
+        n_heads_per_cache_ > 1) {
+      throw std::runtime_error(
+          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
+    }
+
     if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
       // Allocates one extra copy to accommodate caches sliding forward.
       cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
     } else {
-      cache_data_size_ = n_caches_ * cache_len_ * head_dim_;
+      cache_data_size_ =
+          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     }
-    update_data_size_ = n_caches_ * max_input_len_ * head_dim_;
+    update_data_size_ =
+        n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;

     cache_data_ = allocator_.allocate(cache_data_size_);
     update_data_ = allocator_.allocate(update_data_size_);
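
The sizing change multiplies both buffers by `n_heads_per_cache_`; the sliding-cache branch stays head-free because that style rejects more than one head per cache. A standalone sketch of the arithmetic, with illustrative sizes not taken from any real model:

```cpp
#include <cstddef>
#include <cstdio>

// Standalone restatement of the buffer sizing above; all sizes are illustrative.
int main() {
  const std::size_t n_caches = 4;           // e.g. one K (or V) cache per layer
  const std::size_t n_heads_per_cache = 8;  // >1 only for the packed-MHA layout
  const std::size_t cache_len = 512;
  const std::size_t max_input_len = 32;
  const std::size_t head_dim = 64;

  // Non-sliding styles: each head of each cache owns a cache_len x head_dim slab.
  std::size_t cache_data_size =
      n_caches * n_heads_per_cache * cache_len * head_dim;
  // Update buffers hold at most max_input_len new positions per head.
  std::size_t update_data_size =
      n_caches * n_heads_per_cache * max_input_len * head_dim;

  std::printf("cache elements:  %zu\n", cache_data_size);   // 4*8*512*64 = 1048576
  std::printf("update elements: %zu\n", update_data_size);  // 4*8*32*64  = 65536
  return 0;
}
```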
@@ -111,14 +121,40 @@ class StaticKVCache {
     auto outSizes = outMeta->sizes();
     ET_CHECK_MSG(inSizes[0] == 1, "Only support batch size 1.");
     ET_CHECK_MSG(outSizes[0] == 1, "Only support batch size 1.");
+    if (n_heads_per_cache_ > 1) {
+      // More than 1 head per cache, meaning regular MHA is used. Tensor shape
+      // is (1, n_heads, seq_len, head_dim).
+      ET_CHECK_MSG(
+          inSizes.size() == 4, "Cache input tensor expected to have rank 4.");
+      ET_CHECK_MSG(
+          outSizes.size() == 4,
+          "Cache output tensor expected to have rank 4.");
+      ET_CHECK_MSG(
+          inSizes[1] == n_heads_per_cache_,
+          "Number of heads per cache mismatch.");
+      ET_CHECK_MSG(
+          outSizes[1] == n_heads_per_cache_,
+          "Number of heads per cache mismatch.");
+    } else {
+      // 1 head per cache, meaning MHA is split up into multiple SHAs for QNN.
+      // Tensor shape is (1, seq_len, head_dim).
+      ET_CHECK_MSG(
+          inSizes.size() == 3, "Cache input tensor expected to have rank 3.");
+      ET_CHECK_MSG(
+          outSizes.size() == 3,
+          "Cache output tensor expected to have rank 3.");
+    }
+    auto ndim = inSizes.size();
     if (transpose_) {
-      ET_CHECK_MSG(inSizes[1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(inSizes[2] == cache_len_, "Cache length dim mismatch.");
+      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(
+          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
     } else {
-      ET_CHECK_MSG(inSizes[2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(inSizes[1] == cache_len_, "Cache length dim mismatch.");
+      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(
+          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
     }

     auto impl = ::executorch::runtime::etensor::TensorImpl(
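
The new checks accept two layouts: rank-4 `(1, n_heads, seq_len, head_dim)` when heads are packed, and rank-3 `(1, seq_len, head_dim)` when each head has its own cache; the existing head-dim and cache-length checks now index from the back so the same code covers both ranks. A standalone sketch of that indexing, using plain arrays in place of ExecuTorch tensor metadata:

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Illustrative shapes only; std::array stands in for the tensor size metadata.
int main() {
  const std::size_t n_heads_per_cache = 8, cache_len = 512, head_dim = 64;

  // Split-SHA layout (1 head per cache): (1, cache_len, head_dim).
  std::array<std::size_t, 3> sha = {1, cache_len, head_dim};
  // Packed-MHA layout: (1, n_heads, cache_len, head_dim).
  std::array<std::size_t, 4> mha = {1, n_heads_per_cache, cache_len, head_dim};

  // Indexing relative to the rank makes the same non-transposed checks hold for both.
  auto ndim3 = sha.size();
  assert(sha[ndim3 - 1] == head_dim);
  assert(sha[ndim3 - 2] == cache_len);

  auto ndim4 = mha.size();
  assert(mha[1] == n_heads_per_cache);
  assert(mha[ndim4 - 1] == head_dim);
  assert(mha[ndim4 - 2] == cache_len);
  return 0;
}
```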
@@ -180,8 +216,10 @@ class StaticKVCache {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] = cache_data_ + i * cache_len_ * head_dim_;
-      output_ptrs_[i] = update_data_ + i * max_input_len_ * head_dim_;
+      input_ptrs_[i] =
+          cache_data_ + i * n_heads_per_cache_ * cache_len_ * head_dim_;
+      output_ptrs_[i] =
+          update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
     }
   }

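
With the heads factor folded in, cache i's base pointer now strides by `n_heads_per_cache_ * cache_len_ * head_dim_` elements (and the update pointers by the analogous `max_input_len_` stride). A small standalone sketch of that layout, with toy sizes:

```cpp
#include <cstddef>
#include <vector>

// Toy sizes; shows where cache i / head j starts inside the flat buffer.
int main() {
  const std::size_t n_caches = 2, n_heads_per_cache = 4, cache_len = 8, head_dim = 3;
  std::vector<float> cache_data(n_caches * n_heads_per_cache * cache_len * head_dim);

  std::vector<float*> input_ptrs(n_caches);
  for (std::size_t i = 0; i < n_caches; i++) {
    // Each cache owns a contiguous n_heads_per_cache * cache_len * head_dim block.
    input_ptrs[i] = cache_data.data() + i * n_heads_per_cache * cache_len * head_dim;
  }

  // Head j within cache i then starts cache_len * head_dim elements per head further in.
  std::size_t i = 1, j = 2;
  float* head_start = input_ptrs[i] + j * cache_len * head_dim;
  (void)head_start;  // offset 1*96 + 2*24 = 144 floats into cache_data
  return 0;
}
```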
@@ -213,17 +251,23 @@ class StaticKVCache {
       const auto& updateTensor =
           method.get_output(output_indices[i]).toTensor();
       ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + valid_len_ * head_dim_);
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
     valid_len_ += update_len;
   }

   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
+  size_t n_heads_per_cache_;
   size_t head_dim_;
   bool transpose_;
   StaticAttentionUpdateStyle style_;
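
The single `std::copy` becomes a per-head loop: for each head, `update_len` rows of that head's slice of the update buffer are appended after the valid region of that head's cache slice. A standalone sketch of the same copy, with toy sizes:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Toy sizes; mirrors the per-head append above for one cache tensor.
int main() {
  const std::size_t n_heads_per_cache = 2, cache_len = 8, head_dim = 4;
  const std::size_t update_seq_len = 4;  // rows produced per head this step
  std::size_t valid_len = 3;             // rows already filled in each head's cache
  const std::size_t update_pos = 0, update_len = 2;

  std::vector<float> cache(n_heads_per_cache * cache_len * head_dim, 0.f);
  std::vector<float> update(n_heads_per_cache * update_seq_len * head_dim, 1.f);

  for (std::size_t j = 0; j < n_heads_per_cache; j++) {
    float* update_head = update.data() + j * update_seq_len * head_dim;
    float* cache_head = cache.data() + j * cache_len * head_dim;
    std::copy(
        update_head + update_pos * head_dim,
        update_head + (update_pos + update_len) * head_dim,
        cache_head + valid_len * head_dim);  // append after the valid rows
  }
  valid_len += update_len;  // now 5 valid rows per head
  return 0;
}
```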
@@ -414,6 +458,7 @@ class StaticAttentionIOManager {
    size_t cache_len{};
    size_t head_dim{};
    size_t max_input_len{};
+   size_t n_heads_per_cache{};
    size_t attn_mask_input_index{};
    size_t rope_freqs_cos_input_index{};
    size_t rope_freqs_sin_input_index{};
@@ -434,13 +479,15 @@ class StaticAttentionIOManager {
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
+            config_.n_heads_per_cache,
             false,
             config_.style),
         vCaches_(
             config_.n_caches,
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
+            config_.n_heads_per_cache,
             false,
             config_.style) {
     ET_LOG(
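
Finally, the config that the IO manager forwards to both the key and value caches gains the same field. A hedged sketch of filling it out is below; `ConfigSketch` is a stand-in type mirroring only the fields visible in this diff (the real struct lives in this header and may have a different name and more fields), and all values are illustrative:

```cpp
#include <cstddef>

// Stand-in for the real config struct; fields copied from the hunks above,
// everything else about the type is an assumption.
struct ConfigSketch {
  std::size_t n_caches{};
  std::size_t cache_len{};
  std::size_t head_dim{};
  std::size_t max_input_len{};
  std::size_t n_heads_per_cache{};
  std::size_t attn_mask_input_index{};
  std::size_t rope_freqs_cos_input_index{};
  std::size_t rope_freqs_sin_input_index{};
};

int main() {
  ConfigSketch cfg;
  cfg.n_caches = 4;
  cfg.cache_len = 512;
  cfg.head_dim = 64;
  cfg.max_input_len = 32;
  // New field: 1 keeps the QNN-style split-SHA layout (rank-3 caches);
  // setting it to the head count packs a layer's heads into one rank-4 cache.
  cfg.n_heads_per_cache = 8;
  return 0;
}
```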