@@ -48,12 +48,14 @@ class StaticKVCache {
       size_t cache_len,
       size_t head_dim,
       size_t max_input_len = 1,
+      size_t n_heads_per_cache = 1,
       bool transpose = false,
       StaticAttentionUpdateStyle style =
           StaticAttentionUpdateStyle::SLIDING_CACHE)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
+        n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
         transpose_(transpose),
         style_(style),
@@ -63,13 +65,21 @@ class StaticKVCache {
       throw std::runtime_error("Not implemented.");
     }
 
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
+        n_heads_per_cache_ > 1) {
+      throw std::runtime_error(
+          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
+    }
+
     if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
       // Allocates one extra copy to accommodate caches sliding forward.
       cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
     } else {
-      cache_data_size_ = n_caches_ * cache_len_ * head_dim_;
+      cache_data_size_ =
+          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     }
-    update_data_size_ = n_caches_ * max_input_len_ * head_dim_;
+    update_data_size_ =
+        n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
     cache_data_ = allocator_.allocate(cache_data_size_);
     update_data_ = allocator_.allocate(update_data_size_);
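
The size computations above imply a contiguous (cache, head, position, dim) layout for the two flat buffers. As a reference, here is a minimal standalone sketch of the element offset that layout produces, matching the pointer arithmetic in the later hunks; the helper name and signature are illustrative, not part of the class:

#include <cstddef>

// Offset of element (cache_idx, head_idx, pos, dim) within the flat
// cache buffer, assuming the contiguous layout implied by the size
// computations above.
size_t flat_cache_offset(
    size_t cache_idx, // which cache tensor, < n_caches_
    size_t head_idx, // which head, < n_heads_per_cache_
    size_t pos, // sequence position, < cache_len_
    size_t dim, // channel, < head_dim_
    size_t n_heads_per_cache,
    size_t cache_len,
    size_t head_dim) {
  return ((cache_idx * n_heads_per_cache + head_idx) * cache_len + pos) *
      head_dim +
      dim;
}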
@@ -111,14 +121,40 @@ class StaticKVCache {
     auto outSizes = outMeta->sizes();
     ET_CHECK_MSG(inSizes[0] == 1, "Only support batch size 1.");
     ET_CHECK_MSG(outSizes[0] == 1, "Only support batch size 1.");
+    if (n_heads_per_cache_ > 1) {
+      // More than 1 head per cache, meaning regular MHA is used. Tensor shape
+      // is (1, n_heads, seq_len, head_dim).
+      ET_CHECK_MSG(
+          inSizes.size() == 4, "Cache input tensor expected to have rank 4.");
+      ET_CHECK_MSG(
+          outSizes.size() == 4,
+          "Cache output tensor expected to have rank 4.");
+      ET_CHECK_MSG(
+          inSizes[1] == n_heads_per_cache_,
+          "Number of heads per cache mismatch.");
+      ET_CHECK_MSG(
+          outSizes[1] == n_heads_per_cache_,
+          "Number of heads per cache mismatch.");
+    } else {
+      // 1 head per cache, meaning MHA is split up into multiple SHAs for QNN.
+      // Tensor shape is (1, seq_len, head_dim).
+      ET_CHECK_MSG(
+          inSizes.size() == 3, "Cache input tensor expected to have rank 3.");
+      ET_CHECK_MSG(
+          outSizes.size() == 3,
+          "Cache output tensor expected to have rank 3.");
+    }
+    auto ndim = inSizes.size();
     if (transpose_) {
-      ET_CHECK_MSG(inSizes[1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(inSizes[2] == cache_len_, "Cache length dim mismatch.");
+      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(
+          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
     } else {
-      ET_CHECK_MSG(inSizes[2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(inSizes[1] == cache_len_, "Cache length dim mismatch.");
+      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(
+          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
     }
 
     auto impl = ::executorch::runtime::etensor::TensorImpl(
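
To make the two accepted layouts concrete, here is a self-contained sketch of the rank-dependent check for the transpose_ == false case; the free function and the use of std::vector are illustrative stand-ins for the ET_CHECK_MSG logic above:

#include <cstddef>
#include <vector>

// Returns true when a cache tensor shape is valid, mirroring the checks
// above for transpose_ == false. E.g. with 8 heads, cache_len 512, and
// head_dim 64, (1, 8, 512, 64) passes; with 1 head, (1, 512, 64) passes.
bool cache_shape_ok(
    const std::vector<size_t>& sizes,
    size_t n_heads_per_cache,
    size_t cache_len,
    size_t head_dim) {
  const size_t ndim = sizes.size();
  if (sizes.empty() || sizes[0] != 1) {
    return false; // only batch size 1 is supported
  }
  if (n_heads_per_cache > 1) {
    // Regular MHA: (1, n_heads, seq_len, head_dim).
    if (ndim != 4 || sizes[1] != n_heads_per_cache) {
      return false;
    }
  } else if (ndim != 3) {
    // MHA split into one SHA per head: (1, seq_len, head_dim).
    return false;
  }
  // Indexing from the back lets a single path cover both ranks.
  return sizes[ndim - 2] == cache_len && sizes[ndim - 1] == head_dim;
}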
@@ -180,8 +216,10 @@ class StaticKVCache {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] = cache_data_ + i * cache_len_ * head_dim_;
-      output_ptrs_[i] = update_data_ + i * max_input_len_ * head_dim_;
+      input_ptrs_[i] =
+          cache_data_ + i * n_heads_per_cache_ * cache_len_ * head_dim_;
+      output_ptrs_[i] =
+          update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
     }
   }
 
@@ -213,17 +251,23 @@ class StaticKVCache {
       const auto& updateTensor =
           method.get_output(output_indices[i]).toTensor();
       ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + valid_len_ * head_dim_);
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
     valid_len_ += update_len;
   }
 
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
+  size_t n_heads_per_cache_;
   size_t head_dim_;
   bool transpose_;
   StaticAttentionUpdateStyle style_;
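
The new update loop appends each head's slice of the update tensor into that head's slice of the cache. Below is a standalone sketch of the same per-head append, where update and cache stand in for the flat buffers behind output_ptrs_[i] and input_ptrs_[i]; the function name and parameter names are illustrative:

#include <algorithm>
#include <cstddef>

// Append update_len positions from every head of a flat update buffer to
// the corresponding head of a flat cache buffer, as the loop above does.
template <typename T>
void append_per_head(
    const T* update, // n_heads * update_seq_len * head_dim elements
    T* cache, // n_heads * cache_len * head_dim elements
    size_t n_heads,
    size_t update_seq_len,
    size_t cache_len,
    size_t head_dim,
    size_t update_pos, // first new position within the update
    size_t update_len, // number of positions to append
    size_t valid_len) { // current length of valid cache content
  for (size_t j = 0; j < n_heads; j++) {
    const T* update_head = update + j * update_seq_len * head_dim;
    T* cache_head = cache + j * cache_len * head_dim;
    std::copy(
        update_head + update_pos * head_dim,
        update_head + (update_pos + update_len) * head_dim,
        cache_head + valid_len * head_dim);
  }
}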
@@ -414,6 +458,7 @@ class StaticAttentionIOManager {
     size_t cache_len{};
     size_t head_dim{};
     size_t max_input_len{};
+    size_t n_heads_per_cache{};
     size_t attn_mask_input_index{};
     size_t rope_freqs_cos_input_index{};
     size_t rope_freqs_sin_input_index{};
@@ -434,13 +479,15 @@ class StaticAttentionIOManager {
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
+            config_.n_heads_per_cache,
             false,
             config_.style),
         vCaches_(
             config_.n_caches,
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
+            config_.n_heads_per_cache,
             false,
             config_.style) {
     ET_LOG(
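
For context, populating the new config field might look roughly like the sketch below, after which the manager forwards it to both K and V StaticKVCache members via the initializer list above. The config type and the non-sliding update style are not named in this section, so Config and SMART_MASK are local stand-ins and all values are hypothetical:

#include <cstddef>

// Stand-in definitions so the sketch is self-contained; the real types
// live in the surrounding header.
enum class StaticAttentionUpdateStyle { SLIDING_CACHE, SMART_MASK };

struct Config {
  size_t n_caches{};
  size_t cache_len{};
  size_t head_dim{};
  size_t max_input_len{};
  size_t n_heads_per_cache{};
  StaticAttentionUpdateStyle style{StaticAttentionUpdateStyle::SLIDING_CACHE};
};

int main() {
  Config config;
  config.n_caches = 32; // e.g. one K and one V cache for each of 16 layers
  config.cache_len = 512;
  config.head_dim = 64;
  config.max_input_len = 128;
  // > 1 selects the rank-4 (1, n_heads, seq_len, head_dim) layout; it also
  // requires a non-sliding update style, or StaticKVCache throws.
  config.n_heads_per_cache = 8;
  config.style = StaticAttentionUpdateStyle::SMART_MASK;
  return 0;
}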