@@ -48,12 +48,14 @@ class StaticKVCache {
4848 size_t cache_len,
4949 size_t head_dim,
5050 size_t max_input_len = 1 ,
51+ size_t n_heads_per_cache = 1 ,
5152 bool transpose = false ,
5253 StaticAttentionUpdateStyle style =
5354 StaticAttentionUpdateStyle::SLIDING_CACHE)
5455 : n_caches_(n_caches),
5556 cache_len_ (cache_len),
5657 max_input_len_(max_input_len),
58+ n_heads_per_cache_(n_heads_per_cache),
5759 head_dim_(head_dim),
5860 transpose_(transpose),
5961 style_(style),
@@ -63,13 +65,21 @@ class StaticKVCache {
6365 throw std::runtime_error (" Not implemented." );
6466 }
6567
68+ if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
69+ n_heads_per_cache_ > 1 ) {
70+ throw std::runtime_error (
71+ " Sliding cache update strategy doesn't support more than one head per cache tensor." );
72+ }
73+
6674 if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
6775 // Allocates one extra copy to accommodate caches sliding forward.
6876 cache_data_size_ = (n_caches_ + 1 ) * cache_len_ * head_dim_;
6977 } else {
70- cache_data_size_ = n_caches_ * cache_len_ * head_dim_;
78+ cache_data_size_ =
79+ n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
7180 }
72- update_data_size_ = n_caches_ * max_input_len_ * head_dim_;
81+ update_data_size_ =
82+ n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
7383
7484 cache_data_ = allocator_.allocate (cache_data_size_);
7585 update_data_ = allocator_.allocate (update_data_size_);
@@ -111,14 +121,36 @@ class StaticKVCache {
111121 auto outSizes = outMeta->sizes ();
112122 ET_CHECK_MSG (inSizes[0 ] == 1 , " Only support batch size 1." );
113123 ET_CHECK_MSG (outSizes[0 ] == 1 , " Only support batch size 1." );
124+ if (n_heads_per_cache_ > 1 ) {
125+ ET_CHECK_MSG (
126+ inSizes.size () == 4 , " Cache input tensor expected to have rank 4." );
127+ ET_CHECK_MSG (
128+ outSizes.size () == 4 ,
129+ " Cache output tensor expected to have rank 4." );
130+ ET_CHECK_MSG (
131+ inSizes[1 ] == n_heads_per_cache_,
132+ " Number of heads per cache mismatch." );
133+ ET_CHECK_MSG (
134+ outSizes[1 ] == n_heads_per_cache_,
135+ " Number of heads per cache mismatch." );
136+ } else {
137+ ET_CHECK_MSG (
138+ inSizes.size () == 3 , " Cache input tensor expected to have rank 3." );
139+ ET_CHECK_MSG (
140+ outSizes.size () == 3 ,
141+ " Cache output tensor expected to have rank 3." );
142+ }
143+ auto ndim = inSizes.size ();
114144 if (transpose_) {
115- ET_CHECK_MSG (inSizes[1 ] == head_dim_, " KV head dim mismatch." );
116- ET_CHECK_MSG (outSizes[1 ] == head_dim_, " KV head dim mismatch." );
117- ET_CHECK_MSG (inSizes[2 ] == cache_len_, " Cache length dim mismatch." );
145+ ET_CHECK_MSG (inSizes[ndim - 2 ] == head_dim_, " KV head dim mismatch." );
146+ ET_CHECK_MSG (outSizes[ndim - 2 ] == head_dim_, " KV head dim mismatch." );
147+ ET_CHECK_MSG (
148+ inSizes[ndim - 1 ] == cache_len_, " Cache length dim mismatch." );
118149 } else {
119- ET_CHECK_MSG (inSizes[2 ] == head_dim_, " KV head dim mismatch." );
120- ET_CHECK_MSG (outSizes[2 ] == head_dim_, " KV head dim mismatch." );
121- ET_CHECK_MSG (inSizes[1 ] == cache_len_, " Cache length dim mismatch." );
150+ ET_CHECK_MSG (inSizes[ndim - 1 ] == head_dim_, " KV head dim mismatch." );
151+ ET_CHECK_MSG (outSizes[ndim - 1 ] == head_dim_, " KV head dim mismatch." );
152+ ET_CHECK_MSG (
153+ inSizes[ndim - 2 ] == cache_len_, " Cache length dim mismatch." );
122154 }
123155
124156 auto impl = ::executorch::runtime::etensor::TensorImpl (
@@ -180,8 +212,10 @@ class StaticKVCache {
180212 input_ptrs_.resize (n_caches_);
181213 output_ptrs_.resize (n_caches_);
182214 for (size_t i = 0 ; i < n_caches_; i++) {
183- input_ptrs_[i] = cache_data_ + i * cache_len_ * head_dim_;
184- output_ptrs_[i] = update_data_ + i * max_input_len_ * head_dim_;
215+ input_ptrs_[i] =
216+ cache_data_ + i * n_heads_per_cache_ * cache_len_ * head_dim_;
217+ output_ptrs_[i] =
218+ update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
185219 }
186220 }
187221
@@ -213,17 +247,23 @@ class StaticKVCache {
213247 const auto & updateTensor =
214248 method.get_output (output_indices[i]).toTensor ();
215249 ET_CHECK (output_ptrs_[i] == updateTensor.mutable_data_ptr <T>());
216- std::copy (
217- output_ptrs_[i] + update_pos * head_dim_,
218- output_ptrs_[i] + (update_pos + update_len) * head_dim_,
219- input_ptrs_[i] + valid_len_ * head_dim_);
250+ auto update_seq_len = updateTensor.size (updateTensor.dim () - 2 );
251+ for (size_t j = 0 ; j < n_heads_per_cache_; j++) {
252+ auto * update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
253+ auto * cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
254+ std::copy (
255+ update_head + update_pos * head_dim_,
256+ update_head + (update_pos + update_len) * head_dim_,
257+ cache_head + valid_len_ * head_dim_);
258+ }
220259 }
221260 valid_len_ += update_len;
222261 }
223262
224263 size_t n_caches_;
225264 size_t cache_len_;
226265 size_t max_input_len_;
266+ size_t n_heads_per_cache_;
227267 size_t head_dim_;
228268 bool transpose_;
229269 StaticAttentionUpdateStyle style_;
@@ -414,6 +454,7 @@ class StaticAttentionIOManager {
414454 size_t cache_len{};
415455 size_t head_dim{};
416456 size_t max_input_len{};
457+ size_t n_heads_per_cache{};
417458 size_t attn_mask_input_index{};
418459 size_t rope_freqs_cos_input_index{};
419460 size_t rope_freqs_sin_input_index{};
@@ -434,13 +475,15 @@ class StaticAttentionIOManager {
434475 config_.cache_len,
435476 config_.head_dim,
436477 config_.max_input_len,
478+ config_.n_heads_per_cache,
437479 false ,
438480 config_.style),
439481 vCaches_(
440482 config_.n_caches,
441483 config_.cache_len,
442484 config_.head_dim,
443485 config_.max_input_len,
486+ config_.n_heads_per_cache,
444487 false ,
445488 config_.style) {
446489 ET_LOG (
0 commit comments