Commit 6243e01

sxu authored and facebook-github-bot committed

StaticAttentionIOManager: support more than 1 head per cache tensor
Summary: A CoreML model will not split MHA up, so it will have more than one head per cache tensor. Note that only the smart mask update style supports this.

Differential Revision: D78669353
1 parent 0fbd6d4 commit 6243e01
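
For orientation, a minimal sketch of how a runner might opt into the new behavior. The field names match the config used in the diff below; the struct name StaticAttentionIOConfig, the SMART_MASK enumerator, and all concrete numbers are assumptions for illustration, not something this commit introduces.

// Hypothetical configuration for a CoreML-style export that keeps MHA fused.
// StaticAttentionIOConfig and SMART_MASK are assumed names; numbers are illustrative.
StaticAttentionIOConfig config;
config.n_caches = 64;          // e.g. 32 layers x (K + V), one fused tensor each
config.n_heads_per_cache = 8;  // all 8 KV heads live in a single cache tensor
config.cache_len = 512;
config.head_dim = 64;
config.max_input_len = 32;
// Only the smart mask update style supports n_heads_per_cache > 1; the
// sliding cache style now throws in the StaticKVCache constructor.
config.style = StaticAttentionUpdateStyle::SMART_MASK;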


examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 57 additions & 14 deletions
@@ -48,12 +48,14 @@ class StaticKVCache {
       size_t cache_len,
       size_t head_dim,
       size_t max_input_len = 1,
+      size_t n_heads_per_cache = 1,
       bool transpose = false,
       StaticAttentionUpdateStyle style =
           StaticAttentionUpdateStyle::SLIDING_CACHE)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
+        n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
         transpose_(transpose),
         style_(style),
@@ -63,13 +65,21 @@ class StaticKVCache {
       throw std::runtime_error("Not implemented.");
     }
 
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
+        n_heads_per_cache_ > 1) {
+      throw std::runtime_error(
+          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
+    }
+
     if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
       // Allocates on extra copy to accomodate caches sliding forward.
       cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
     } else {
-      cache_data_size_ = n_caches_ * cache_len_ * head_dim_;
+      cache_data_size_ =
+          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     }
-    update_data_size_ = n_caches_ * max_input_len_ * head_dim_;
+    update_data_size_ =
+        n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
     cache_data_ = allocator_.allocate(cache_data_size_);
     update_data_ = allocator_.allocate(update_data_size_);
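
As a sanity check on the new size arithmetic, here is a small self-contained sketch (plain C++, illustrative numbers) that mirrors the smart-mask branch above: each cache tensor now spans n_heads_per_cache contiguous [cache_len, head_dim] planes, and the update buffer grows by the same factor.

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative values; the real ones come from the exported model.
  std::size_t n_caches = 64, n_heads_per_cache = 8;
  std::size_t cache_len = 512, head_dim = 64, max_input_len = 32;

  // Mirrors the non-sliding branch above: every cache tensor holds
  // n_heads_per_cache heads back to back.
  std::size_t cache_data_size = n_caches * n_heads_per_cache * cache_len * head_dim;
  std::size_t update_data_size = n_caches * n_heads_per_cache * max_input_len * head_dim;

  // Per-cache strides used later when slicing out cache i (see the
  // input_ptrs_ / output_ptrs_ hunk further down).
  std::size_t cache_stride = n_heads_per_cache * cache_len * head_dim;
  std::size_t update_stride = n_heads_per_cache * max_input_len * head_dim;

  std::printf(
      "cache elements: %zu, update elements: %zu, per-cache strides: %zu / %zu\n",
      cache_data_size, update_data_size, cache_stride, update_stride);
  return 0;
}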
@@ -111,14 +121,36 @@ class StaticKVCache {
     auto outSizes = outMeta->sizes();
     ET_CHECK_MSG(inSizes[0] == 1, "Only support batch size 1.");
     ET_CHECK_MSG(outSizes[0] == 1, "Only support batch size 1.");
+    if (n_heads_per_cache_ > 1) {
+      ET_CHECK_MSG(
+          inSizes.size() == 4, "Cache input tensor expected to have rank 4.");
+      ET_CHECK_MSG(
+          outSizes.size() == 4,
+          "Cache input tensor expected to have rank 4.");
+      ET_CHECK_MSG(
+          inSizes[1] == n_heads_per_cache_,
+          "Number of heads per cache mismatch.");
+      ET_CHECK_MSG(
+          outSizes[1] == n_heads_per_cache_,
+          "Number of heads per cache mismatch.");
+    } else {
+      ET_CHECK_MSG(
+          inSizes.size() == 3, "Cache input tensor expected to have rank 3.");
+      ET_CHECK_MSG(
+          outSizes.size() == 3,
+          "Cache input tensor expected to have rank 3.");
+    }
+    auto ndim = inSizes.size();
     if (transpose_) {
-      ET_CHECK_MSG(inSizes[1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(inSizes[2] == cache_len_, "Cache length dim mismatch.");
+      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(
+          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
     } else {
-      ET_CHECK_MSG(inSizes[2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(inSizes[1] == cache_len_, "Cache length dim mismatch.");
+      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(
+          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
     }
 
     auto impl = ::executorch::runtime::etensor::TensorImpl(
@@ -180,8 +212,10 @@ class StaticKVCache {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] = cache_data_ + i * cache_len_ * head_dim_;
-      output_ptrs_[i] = update_data_ + i * max_input_len_ * head_dim_;
+      input_ptrs_[i] =
+          cache_data_ + i * n_heads_per_cache_ * cache_len_ * head_dim_;
+      output_ptrs_[i] =
+          update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
     }
   }
 
@@ -213,17 +247,23 @@ class StaticKVCache {
       const auto& updateTensor =
           method.get_output(output_indices[i]).toTensor();
       ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + valid_len_ * head_dim_);
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
     valid_len_ += update_len;
   }
 
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
+  size_t n_heads_per_cache_;
   size_t head_dim_;
   bool transpose_;
   StaticAttentionUpdateStyle style_;
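
The new inner loop above appends update_len rows to each head's slice of the cache rather than treating the whole tensor as a single head. A self-contained sketch of the same offset arithmetic, using plain float buffers instead of ExecuTorch tensors and illustrative sizes:

#include <algorithm>
#include <cstddef>
#include <vector>

int main() {
  // Illustrative sizes; in the runner these come from the model's cache tensors.
  const std::size_t n_heads_per_cache = 2, cache_len = 8, head_dim = 4;
  const std::size_t update_seq_len = 3;  // rows produced by the model this step
  const std::size_t update_pos = 0, update_len = 3;
  std::size_t valid_len = 5;             // rows already filled in each head's cache

  std::vector<float> cache(n_heads_per_cache * cache_len * head_dim, 0.f);
  std::vector<float> update(n_heads_per_cache * update_seq_len * head_dim, 1.f);

  // Same structure as the loop above: each head owns a contiguous
  // [cache_len, head_dim] (resp. [update_seq_len, head_dim]) plane.
  for (std::size_t j = 0; j < n_heads_per_cache; j++) {
    float* update_head = update.data() + update_seq_len * head_dim * j;
    float* cache_head = cache.data() + cache_len * head_dim * j;
    std::copy(
        update_head + update_pos * head_dim,
        update_head + (update_pos + update_len) * head_dim,
        cache_head + valid_len * head_dim);
  }
  valid_len += update_len;
  return 0;
}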
@@ -414,6 +454,7 @@ class StaticAttentionIOManager {
     size_t cache_len{};
     size_t head_dim{};
     size_t max_input_len{};
+    size_t n_heads_per_cache{};
     size_t attn_mask_input_index{};
     size_t rope_freqs_cos_input_index{};
     size_t rope_freqs_sin_input_index{};
@@ -434,13 +475,15 @@ class StaticAttentionIOManager {
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
+            config_.n_heads_per_cache,
             false,
             config_.style),
         vCaches_(
             config_.n_caches,
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
+            config_.n_heads_per_cache,
             false,
             config_.style) {
     ET_LOG(
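
Putting the rank checks from the earlier hunk together with the transpose = false passed here, the cache inputs are expected to look as follows; the shapes below are a sketch with illustrative sizes, not something emitted by this code.

#include <array>
#include <cstddef>

int main() {
  constexpr std::size_t cache_len = 512, head_dim = 64, n_heads_per_cache = 8;

  // n_heads_per_cache == 1: rank-3 cache input, [1, cache_len, head_dim]
  // (batch dim must be 1, per the ET_CHECK_MSG calls above).
  constexpr std::array<std::size_t, 3> single_head_shape{1, cache_len, head_dim};

  // n_heads_per_cache > 1 (fused MHA, e.g. a CoreML export): rank-4 cache input,
  // [1, n_heads_per_cache, cache_len, head_dim]; dim 1 must equal n_heads_per_cache.
  constexpr std::array<std::size_t, 4> fused_shape{
      1, n_heads_per_cache, cache_len, head_dim};

  static_assert(single_head_shape[1] == fused_shape[2], "cache_len dim");
  static_assert(fused_shape[1] == n_heads_per_cache, "heads dim");
  return 0;
}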
