Commit 4df7223

StaticAttentionIOManager: support more than 1 head per cache tensor

Differential Revision: D78669353
Pull Request resolved: #12681

Parent: 6c2b94e

File tree: 1 file changed (+61 −14)

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 61 additions & 14 deletions
@@ -48,12 +48,14 @@ class StaticKVCache {
       size_t cache_len,
       size_t head_dim,
       size_t max_input_len = 1,
+      size_t n_heads_per_cache = 1,
       bool transpose = false,
       StaticAttentionUpdateStyle style =
           StaticAttentionUpdateStyle::SLIDING_CACHE)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
+        n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
         transpose_(transpose),
         style_(style),
@@ -63,13 +65,21 @@ class StaticKVCache {
       throw std::runtime_error("Not implemented.");
     }
 
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
+        n_heads_per_cache_ > 1) {
+      throw std::runtime_error(
+          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
+    }
+
     if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
       // Allocates on extra copy to accomodate caches sliding forward.
       cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
     } else {
-      cache_data_size_ = n_caches_ * cache_len_ * head_dim_;
+      cache_data_size_ =
+          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     }
-    update_data_size_ = n_caches_ * max_input_len_ * head_dim_;
+    update_data_size_ =
+        n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
     cache_data_ = allocator_.allocate(cache_data_size_);
     update_data_ = allocator_.allocate(update_data_size_);
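Note on the hunk above: with the non-sliding update style, every cache tensor now holds n_heads_per_cache contiguous head planes, so both buffers scale linearly in the head count (the diff also rejects SLIDING_CACHE when n_heads_per_cache > 1, so the multi-head case assumes the non-sliding style). A minimal standalone sketch of that sizing arithmetic, with illustrative numbers that are not taken from the commit:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative values only; real values come from the exported model.
  size_t n_caches = 16;          // number of K (or V) cache tensors
  size_t n_heads_per_cache = 4;  // > 1 means regular MHA layout
  size_t cache_len = 512;
  size_t head_dim = 64;
  size_t max_input_len = 32;

  // Mirrors the non-sliding branch above: each cache tensor now stores
  // n_heads_per_cache contiguous (cache_len x head_dim) planes.
  size_t cache_data_size = n_caches * n_heads_per_cache * cache_len * head_dim;
  size_t update_data_size =
      n_caches * n_heads_per_cache * max_input_len * head_dim;

  std::printf("cache elements:  %zu\n", cache_data_size);   // 2097152
  std::printf("update elements: %zu\n", update_data_size);  // 131072
  return 0;
}
```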
@@ -111,14 +121,40 @@ class StaticKVCache {
     auto outSizes = outMeta->sizes();
     ET_CHECK_MSG(inSizes[0] == 1, "Only support batch size 1.");
     ET_CHECK_MSG(outSizes[0] == 1, "Only support batch size 1.");
+    if (n_heads_per_cache_ > 1) {
+      // More than 1 head per cache, meaning regular MHA is used. Tensor shape
+      // is (1, n_heads, seq_len, head_dim).
+      ET_CHECK_MSG(
+          inSizes.size() == 4, "Cache input tensor expected to have rank 4.");
+      ET_CHECK_MSG(
+          outSizes.size() == 4,
+          "Cache input tensor expected to have rank 4.");
+      ET_CHECK_MSG(
+          inSizes[1] == n_heads_per_cache_,
+          "Number of heads per cache mismatch.");
+      ET_CHECK_MSG(
+          outSizes[1] == n_heads_per_cache_,
+          "Number of heads per cache mismatch.");
+    } else {
+      // 1 head per cache, meaning MHA is split up into multiple SHAs for QNN.
+      // Tensor shape is (1, seq_len, head_dim).
+      ET_CHECK_MSG(
+          inSizes.size() == 3, "Cache input tensor expected to have rank 3.");
+      ET_CHECK_MSG(
+          outSizes.size() == 3,
+          "Cache input tensor expected to have rank 3.");
+    }
+    auto ndim = inSizes.size();
     if (transpose_) {
-      ET_CHECK_MSG(inSizes[1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(inSizes[2] == cache_len_, "Cache length dim mismatch.");
+      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(
+          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
     } else {
-      ET_CHECK_MSG(inSizes[2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(inSizes[1] == cache_len_, "Cache length dim mismatch.");
+      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+      ET_CHECK_MSG(
+          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
     }
 
     auto impl = ::executorch::runtime::etensor::TensorImpl(
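The rank checks added above distinguish the two exported layouts: rank 4 (1, n_heads, seq_len, head_dim) when a cache tensor keeps all of its heads, and rank 3 (1, seq_len, head_dim) when MHA is split into per-head SHAs. A hedged sketch of the same dispatch, written against plain std::vector shapes rather than ExecuTorch tensor metadata (the helper name is made up for illustration):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the commit): pick the dims to validate
// relative to the tensor's rank so one function covers both layouts.
void check_cache_shape(
    const std::vector<size_t>& sizes,
    size_t n_heads_per_cache,
    size_t cache_len,
    size_t head_dim,
    bool transpose) {
  assert(sizes[0] == 1); // batch size 1 only
  if (n_heads_per_cache > 1) {
    // Regular MHA: (1, n_heads, seq_len, head_dim).
    assert(sizes.size() == 4);
    assert(sizes[1] == n_heads_per_cache);
  } else {
    // MHA split into per-head SHAs (e.g. for QNN): (1, seq_len, head_dim).
    assert(sizes.size() == 3);
  }
  size_t ndim = sizes.size();
  if (transpose) {
    assert(sizes[ndim - 2] == head_dim);
    assert(sizes[ndim - 1] == cache_len);
  } else {
    assert(sizes[ndim - 2] == cache_len);
    assert(sizes[ndim - 1] == head_dim);
  }
}

int main() {
  check_cache_shape({1, 4, 512, 64}, 4, 512, 64, /*transpose=*/false);
  check_cache_shape({1, 512, 64}, 1, 512, 64, /*transpose=*/false);
  return 0;
}
```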
@@ -180,8 +216,10 @@ class StaticKVCache {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] = cache_data_ + i * cache_len_ * head_dim_;
-      output_ptrs_[i] = update_data_ + i * max_input_len_ * head_dim_;
+      input_ptrs_[i] =
+          cache_data_ + i * n_heads_per_cache_ * cache_len_ * head_dim_;
+      output_ptrs_[i] =
+          update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
     }
   }
 
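The pointer setup above simply grows each per-cache stride by the head count. A small self-contained sketch of that layout, using illustrative sizes and float buffers in place of the ExecuTorch allocator:

```cpp
#include <cstddef>
#include <vector>

int main() {
  // Illustrative sizes, not taken from the commit.
  size_t n_caches = 2, n_heads_per_cache = 4, cache_len = 8, head_dim = 16;
  size_t max_input_len = 4;

  std::vector<float> cache_data(
      n_caches * n_heads_per_cache * cache_len * head_dim);
  std::vector<float> update_data(
      n_caches * n_heads_per_cache * max_input_len * head_dim);

  // Mirrors the updated loop: each cache's base pointer skips over the
  // n_heads_per_cache contiguous head planes of the caches before it.
  std::vector<float*> input_ptrs(n_caches);
  std::vector<float*> output_ptrs(n_caches);
  for (size_t i = 0; i < n_caches; i++) {
    input_ptrs[i] =
        cache_data.data() + i * n_heads_per_cache * cache_len * head_dim;
    output_ptrs[i] =
        update_data.data() + i * n_heads_per_cache * max_input_len * head_dim;
  }
  return 0;
}
```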
@@ -213,17 +251,23 @@ class StaticKVCache {
       const auto& updateTensor =
           method.get_output(output_indices[i]).toTensor();
       ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + valid_len_ * head_dim_);
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
     valid_len_ += update_len;
   }
 
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
+  size_t n_heads_per_cache_;
   size_t head_dim_;
   bool transpose_;
   StaticAttentionUpdateStyle style_;
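The update path above is the core of the change: instead of one std::copy per cache, the new rows are appended head plane by head plane, because the update buffer and the cache buffer use different per-head strides (update_seq_len vs cache_len rows). A hedged standalone sketch of the same copy logic, with a hypothetical free function standing in for the member method:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper: append freshly produced rows after the currently
// valid prefix of every head plane inside one cache tensor.
void update_one_cache(
    const float* update,  // (n_heads_per_cache, update_seq_len, head_dim)
    float* cache,         // (n_heads_per_cache, cache_len, head_dim)
    size_t n_heads_per_cache,
    size_t update_seq_len,
    size_t cache_len,
    size_t head_dim,
    size_t update_pos,
    size_t update_len,
    size_t valid_len) {
  for (size_t j = 0; j < n_heads_per_cache; j++) {
    const float* update_head = update + update_seq_len * head_dim * j;
    float* cache_head = cache + cache_len * head_dim * j;
    std::copy(
        update_head + update_pos * head_dim,
        update_head + (update_pos + update_len) * head_dim,
        cache_head + valid_len * head_dim);
  }
}

int main() {
  size_t n_heads = 2, update_seq_len = 4, cache_len = 8, head_dim = 3;
  std::vector<float> update(n_heads * update_seq_len * head_dim, 1.0f);
  std::vector<float> cache(n_heads * cache_len * head_dim, 0.0f);
  // Append 2 new positions (starting at update row 0) after 3 valid rows.
  update_one_cache(
      update.data(), cache.data(), n_heads, update_seq_len, cache_len,
      head_dim, /*update_pos=*/0, /*update_len=*/2, /*valid_len=*/3);
  return 0;
}
```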
@@ -414,6 +458,7 @@ class StaticAttentionIOManager {
   size_t cache_len{};
   size_t head_dim{};
   size_t max_input_len{};
+  size_t n_heads_per_cache{};
   size_t attn_mask_input_index{};
   size_t rope_freqs_cos_input_index{};
   size_t rope_freqs_sin_input_index{};
@@ -434,13 +479,15 @@ class StaticAttentionIOManager {
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
+            config_.n_heads_per_cache,
             false,
             config_.style),
         vCaches_(
             config_.n_caches,
             config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
+            config_.n_heads_per_cache,
             false,
             config_.style) {
     ET_LOG(
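Downstream, callers only need to populate the new config field before constructing the manager, which then forwards it to both the K and V caches as shown above. A hedged usage sketch, using a locally defined stand-in struct because only the field names (not the config type itself) are visible in this diff:

```cpp
#include <cstddef>

// Hypothetical stand-in for the manager's config; the field names below are
// the ones visible in this diff, but the struct name is illustrative only.
struct StaticAttentionIOConfigSketch {
  size_t n_caches{};
  size_t cache_len{};
  size_t head_dim{};
  size_t max_input_len{};
  size_t n_heads_per_cache{};  // new field introduced by this commit
  size_t attn_mask_input_index{};
  size_t rope_freqs_cos_input_index{};
  size_t rope_freqs_sin_input_index{};
};

int main() {
  StaticAttentionIOConfigSketch config;
  config.n_caches = 16;
  config.cache_len = 512;
  config.head_dim = 64;
  config.max_input_len = 32;
  // Leave at 1 when MHA is split into per-head SHAs (e.g. for QNN); set to
  // the actual head count when each cache tensor keeps all of its heads.
  config.n_heads_per_cache = 4;
  return 0;
}
```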
