
Commit 0c8879e

StaticAttentionIOManager: remove unsupported options
Differential Revision: D79108366
Pull Request resolved: #12931
Parent: 339e95f

File tree

1 file changed: +27 -109 lines changed

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 27 additions & 109 deletions
@@ -19,16 +19,9 @@
 namespace example {
 
 enum class StaticAttentionUpdateStyle {
-  /**
-   * KV caches will have valid data at the end of the cache. New elements are
-   * added at the end and the start of the cache will slide forward to maintain
-   * this invariant. This potentially allows shorter caches to be passed into
-   * the model by adjusting the start pointer.
-   */
-  SLIDING_CACHE,
   /**
    * I/O pointers do not change which can enable persistent memory mapping
-   * between AP and NPU.
+   * between AP and NPU. However cache updates need to be copied.
    */
   SMART_MASK,
 };
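
With SLIDING_CACHE removed, SMART_MASK is the only update style: cache buffers keep fixed addresses, and each update is copied in behind the entries that are already valid. As a rough mental model (a standalone sketch of the idea, not the ExecuTorch API), the pattern looks like this:

// Standalone sketch of the SMART_MASK idea (illustrative only, not the
// ExecuTorch API): the cache buffer address never changes, so each update is
// a copy of the newly produced rows into the first unused slot.
#include <algorithm>
#include <cstddef>
#include <vector>

struct ToyCache {
  size_t cache_len;   // maximum number of cached positions
  size_t head_dim;    // elements per position
  size_t valid_len = 0;
  std::vector<float> data;  // fixed allocation; its address stays stable

  ToyCache(size_t cache_len, size_t head_dim)
      : cache_len(cache_len), head_dim(head_dim), data(cache_len * head_dim) {}

  // Copy `update_len` rows from `update` into the cache, smart-mask style.
  void update(const float* update, size_t update_len) {
    std::copy(
        update,
        update + update_len * head_dim,
        data.data() + valid_len * head_dim);
    valid_len += update_len;
  }
};
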
@@ -49,35 +42,16 @@ class StaticKVCache {
       size_t head_dim,
       size_t max_input_len = 1,
       size_t n_heads_per_cache = 1,
-      bool transpose = false,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+      StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
-        transpose_(transpose),
         style_(style),
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
-        n_heads_per_cache_ > 1) {
-      throw std::runtime_error(
-          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      // Allocates on extra copy to accomodate caches sliding forward.
-      cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
-    } else {
-      cache_data_size_ =
-          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
-    }
+    cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
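
The simplified constructor now sizes the flat cache and update buffers directly from the cache geometry, with no extra copy for a sliding window. For a sense of scale, here is a small self-contained computation with made-up parameters (the numbers are purely illustrative):

// Hypothetical sizing example (numbers chosen only for illustration): with
// n_caches = 16, n_heads_per_cache = 1, cache_len = 1024, max_input_len = 32,
// head_dim = 64, the two flat allocations work out as follows.
#include <cstddef>
#include <cstdio>

int main() {
  size_t n_caches = 16, n_heads_per_cache = 1, cache_len = 1024,
         max_input_len = 32, head_dim = 64;
  size_t cache_data_size = n_caches * n_heads_per_cache * cache_len * head_dim;
  size_t update_data_size =
      n_caches * n_heads_per_cache * max_input_len * head_dim;
  // Prints "cache elems: 1048576, update elems: 32768".
  std::printf(
      "cache elems: %zu, update elems: %zu\n",
      cache_data_size,
      update_data_size);
  return 0;
}
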

@@ -145,17 +119,10 @@
           "Cache input tensor expected to have rank 3.");
     }
     auto ndim = inSizes.size();
-    if (transpose_) {
-      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
-    } else {
-      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
-    }
+    ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(
+        inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
 
     auto impl = ::executorch::runtime::etensor::TensorImpl(
         inMeta->scalar_type(),
@@ -189,11 +156,21 @@
       throw std::runtime_error("Cache capacity exceeded.");
     }
 
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      update_sliding_cache(method, output_indices, update_len, update_pos);
-    } else {
-      update_smart_mask(method, output_indices, update_len, update_pos);
+    for (size_t i = 0; i < n_caches_; i++) {
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
+    valid_len_ += update_len;
   }
 
   /**
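
The copy loop above treats each cache tensor as n_heads_per_cache_ contiguous heads of cache_len_ x head_dim_ elements, and each update tensor as the same number of heads of update_seq_len x head_dim_; the new rows land at row valid_len_ of each cache head. A minimal sketch of that offset arithmetic (names mirror the members, but this is an illustrative reconstruction, not library code):

// Sketch of the per-head offset arithmetic used in the update loop above.
#include <cstddef>

size_t cache_dest_offset(
    size_t head,       // j in the loop above
    size_t valid_len,  // rows already present in the cache
    size_t cache_len,
    size_t head_dim) {
  // Head j starts at cache_len * head_dim * j; the next free row within that
  // head starts valid_len rows further in.
  return cache_len * head_dim * head + valid_len * head_dim;
}

size_t update_src_offset(
    size_t head,            // j in the loop above
    size_t update_pos,      // first new row inside the update tensor
    size_t update_seq_len,  // rows produced by this inference
    size_t head_dim) {
  // Head j of the update starts at update_seq_len * head_dim * j; the rows to
  // copy start update_pos rows into that head.
  return update_seq_len * head_dim * head + update_pos * head_dim;
}
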
@@ -202,9 +179,6 @@
    */
   void reset() {
     valid_len_ = 0;
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      init_ptrs();
-    }
   }
 
   size_t size() {
@@ -223,53 +197,11 @@
     }
   }
 
-  void update_sliding_cache(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    ET_CHECK(n_caches_ == output_indices.size());
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + cache_len_ * head_dim_);
-      input_ptrs_[i] += update_len * head_dim_;
-    }
-    valid_len_ += update_len;
-  }
-
-  void update_smart_mask(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
-      for (size_t j = 0; j < n_heads_per_cache_; j++) {
-        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
-        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
-        std::copy(
-            update_head + update_pos * head_dim_,
-            update_head + (update_pos + update_len) * head_dim_,
-            cache_head + valid_len_ * head_dim_);
-      }
-    }
-    valid_len_ += update_len;
-  }
-
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
   size_t head_dim_;
-  bool transpose_;
   StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
   size_t cache_data_size_;
@@ -300,8 +232,7 @@ class StaticAttentionMask {
       size_t head_dim,
       T zero_val,
       T mask_val,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+      StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
       : cache_len_(cache_len),
         input_len_(input_len),
         head_dim_(head_dim),
@@ -341,20 +272,10 @@
    * prefilling with padded inputs.
    */
   void unmask(size_t update_len) {
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_len_ - cache_valid_len_ - update_len,
-            p + cache_len_ - cache_valid_len_,
-            zero_val_);
-      }
-    } else {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
-      }
+    for (size_t i = 0; i < input_len_; i++) {
+      auto* p = data_ + (cache_len_ + input_len_) * i;
+      std::fill(
+          p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
     }
     cache_valid_len_ += update_len;
   }
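
unmask now always uses the smart-mask layout: input_len_ rows of cache_len_ + input_len_ columns, where the next update_len cache columns of every row are set to the zero value. A standalone sketch of that same step (illustrative only, not the library API):

// Standalone sketch of the smart-mask unmask step: the mask has input_len
// rows of (cache_len + input_len) columns; unmasking zeroes the next
// `update_len` cache columns of every row.
#include <algorithm>
#include <cstddef>
#include <vector>

void toy_unmask(
    std::vector<float>& mask,  // size input_len * (cache_len + input_len)
    size_t input_len,
    size_t cache_len,
    size_t& cache_valid_len,
    size_t update_len,
    float zero_val) {
  for (size_t i = 0; i < input_len; i++) {
    float* row = mask.data() + (cache_len + input_len) * i;
    std::fill(
        row + cache_valid_len, row + cache_valid_len + update_len, zero_val);
  }
  cache_valid_len += update_len;
}
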
@@ -468,8 +389,7 @@ class StaticAttentionIOManager {
     std::vector<size_t> v_cache_output_indices;
     RopeT* rope_freqs_cos;
     RopeT* rope_freqs_sin;
-    StaticAttentionUpdateStyle style =
-        StaticAttentionUpdateStyle::SLIDING_CACHE;
+    StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
   };
 
   StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -480,15 +400,13 @@
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style),
         vCaches_(
             config_.n_caches,
            config_.cache_len,
            config_.head_dim,
            config_.max_input_len,
            config_.n_heads_per_cache,
-            false,
            config_.style) {
     ET_LOG(
         Info,
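
After this change, StaticKVCache is constructed without the transpose flag, and SMART_MASK is both the default and the only remaining update style. A hypothetical direct construction, where the include path, template argument, and numbers are assumptions chosen only for illustration:

#include "static_attention_io_manager.h"  // adjust to your include path

void construct_example() {
  // No `transpose` argument anymore; SMART_MASK is the default style.
  example::StaticKVCache<float> k_caches(
      /*n_caches=*/16,
      /*cache_len=*/1024,
      /*head_dim=*/64,
      /*max_input_len=*/32,
      /*n_heads_per_cache=*/1,
      example::StaticAttentionUpdateStyle::SMART_MASK);
}
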
