
Commit 5ab0a92

sxu authored and facebook-github-bot committed
StaticAttentionIOManager: remove unsupported options
Summary:
Simplify code by removing unsupported options:

- Storing the cache transposed: not implemented; there doesn't seem to be a need for it so far.
- Sliding cache window: cannot support local-global attention; removed.

Differential Revision: D79108366
1 parent 03f6bcc commit 5ab0a92
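With these options gone, the only remaining update path is the SMART_MASK style: cache I/O pointers stay fixed and each update is copied in right after the currently valid region (see the update() hunk below). Here is a minimal, self-contained sketch of that copy, using plain std::vector buffers instead of ExecuTorch tensors; SketchKVCache and its members are illustrative names, not the actual StaticKVCache API.

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of a SMART_MASK-style cache update. One cache is laid out as
// [n_heads][cache_len][head_dim]; the buffer never moves, and new rows are
// copied in at the current valid length.
struct SketchKVCache {
  size_t n_heads;
  size_t cache_len;
  size_t head_dim;
  size_t valid_len = 0;
  std::vector<float> data; // n_heads * cache_len * head_dim

  SketchKVCache(size_t n_heads, size_t cache_len, size_t head_dim)
      : n_heads(n_heads),
        cache_len(cache_len),
        head_dim(head_dim),
        data(n_heads * cache_len * head_dim) {}

  // `update` is laid out as [n_heads][update_seq_len][head_dim]. Copy
  // `update_len` rows starting at `update_pos` to the end of the valid
  // region of every head, then advance valid_len.
  void update(
      const std::vector<float>& update,
      size_t update_seq_len,
      size_t update_len,
      size_t update_pos) {
    for (size_t h = 0; h < n_heads; h++) {
      const float* update_head = update.data() + update_seq_len * head_dim * h;
      float* cache_head = data.data() + cache_len * head_dim * h;
      std::copy(
          update_head + update_pos * head_dim,
          update_head + (update_pos + update_len) * head_dim,
          cache_head + valid_len * head_dim);
    }
    valid_len += update_len;
  }
};

Keeping the buffer fixed is what allows the pointers to be persistently memory mapped between AP and NPU; the trade-off, as the updated enum comment notes, is that every update requires a copy.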

File tree

1 file changed: +27 additions, −106 deletions


examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 27 additions & 106 deletions
@@ -19,16 +19,9 @@
 namespace example {
 
 enum class StaticAttentionUpdateStyle {
-  /**
-   * KV caches will have valid data at the end of the cache. New elements are
-   * added at the end and the start of the cache will slide forward to maintain
-   * this invariant. This potentially allows shorter caches to be passed into
-   * the model by adjusting the start pointer.
-   */
-  SLIDING_CACHE,
   /**
    * I/O pointers do not change which can enable persistent memory mapping
-   * between AP and NPU.
+   * between AP and NPU. However cache updates need to be copied.
    */
   SMART_MASK,
 };
@@ -49,35 +42,17 @@ class StaticKVCache {
       size_t head_dim,
       size_t max_input_len = 1,
       size_t n_heads_per_cache = 1,
-      bool transpose = false,
       StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+          StaticAttentionUpdateStyle::SMART_MASK)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
         head_dim_(head_dim),
-        transpose_(transpose),
         style_(style),
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE &&
-        n_heads_per_cache_ > 1) {
-      throw std::runtime_error(
-          "Sliding cache update strategy doesn't support more than one head per cache tensor.");
-    }
-
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      // Allocates on extra copy to accomodate caches sliding forward.
-      cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
-    } else {
-      cache_data_size_ =
-          n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
-    }
+    cache_data_size_ = n_caches_ * n_heads_per_cache_ * cache_len_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
 
@@ -145,17 +120,10 @@ class StaticKVCache {
         "Cache input tensor expected to have rank 3.");
     }
     auto ndim = inSizes.size();
-    if (transpose_) {
-      ET_CHECK_MSG(inSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 2] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 1] == cache_len_, "Cache length dim mismatch.");
-    } else {
-      ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
-      ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
-    }
+    ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
+    ET_CHECK_MSG(
+        inSizes[ndim - 2] == cache_len_, "Cache length dim mismatch.");
 
     auto impl = ::executorch::runtime::etensor::TensorImpl(
         inMeta->scalar_type(),
@@ -189,11 +157,21 @@ class StaticKVCache {
       throw std::runtime_error("Cache capacity exceeded.");
     }
 
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      update_sliding_cache(method, output_indices, update_len, update_pos);
-    } else {
-      update_smart_mask(method, output_indices, update_len, update_pos);
+    for (size_t i = 0; i < n_caches_; i++) {
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
+      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
+      for (size_t j = 0; j < n_heads_per_cache_; j++) {
+        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
+        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
+        std::copy(
+            update_head + update_pos * head_dim_,
+            update_head + (update_pos + update_len) * head_dim_,
+            cache_head + valid_len_ * head_dim_);
+      }
     }
+    valid_len_ += update_len;
   }
 
   /**
@@ -202,9 +180,6 @@ class StaticKVCache {
    */
   void reset() {
     valid_len_ = 0;
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      init_ptrs();
-    }
   }
 
   size_t size() {
@@ -223,53 +198,11 @@ class StaticKVCache {
     }
   }
 
-  void update_sliding_cache(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    ET_CHECK(n_caches_ == output_indices.size());
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
-      std::copy(
-          output_ptrs_[i] + update_pos * head_dim_,
-          output_ptrs_[i] + (update_pos + update_len) * head_dim_,
-          input_ptrs_[i] + cache_len_ * head_dim_);
-      input_ptrs_[i] += update_len * head_dim_;
-    }
-    valid_len_ += update_len;
-  }
-
-  void update_smart_mask(
-      torch::executor::Method& method,
-      const std::vector<size_t>& output_indices,
-      size_t update_len,
-      size_t update_pos) {
-    for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor =
-          method.get_output(output_indices[i]).toTensor();
-      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
-      auto update_seq_len = updateTensor.size(updateTensor.dim() - 2);
-      for (size_t j = 0; j < n_heads_per_cache_; j++) {
-        auto* update_head = output_ptrs_[i] + update_seq_len * head_dim_ * j;
-        auto* cache_head = input_ptrs_[i] + cache_len_ * head_dim_ * j;
-        std::copy(
-            update_head + update_pos * head_dim_,
-            update_head + (update_pos + update_len) * head_dim_,
-            cache_head + valid_len_ * head_dim_);
-      }
-    }
-    valid_len_ += update_len;
-  }
-
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
   size_t head_dim_;
-  bool transpose_;
   StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
   size_t cache_data_size_;
@@ -301,7 +234,7 @@ class StaticAttentionMask {
       T zero_val,
       T mask_val,
       StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
+          StaticAttentionUpdateStyle::SMART_MASK)
       : cache_len_(cache_len),
         input_len_(input_len),
         head_dim_(head_dim),
@@ -341,20 +274,10 @@ class StaticAttentionMask {
    * prefilling with padded inputs.
    */
   void unmask(size_t update_len) {
-    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_len_ - cache_valid_len_ - update_len,
-            p + cache_len_ - cache_valid_len_,
-            zero_val_);
-      }
-    } else {
-      for (size_t i = 0; i < input_len_; i++) {
-        auto* p = data_ + (cache_len_ + input_len_) * i;
-        std::fill(
-            p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
-      }
+    for (size_t i = 0; i < input_len_; i++) {
+      auto* p = data_ + (cache_len_ + input_len_) * i;
+      std::fill(
+          p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
     }
     cache_valid_len_ += update_len;
   }
@@ -469,7 +392,7 @@ class StaticAttentionIOManager {
     RopeT* rope_freqs_cos;
     RopeT* rope_freqs_sin;
     StaticAttentionUpdateStyle style =
-        StaticAttentionUpdateStyle::SLIDING_CACHE;
+        StaticAttentionUpdateStyle::SMART_MASK;
   };
 
   StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -480,15 +403,13 @@ class StaticAttentionIOManager {
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style),
         vCaches_(
             config_.n_caches,
            config_.cache_len,
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
-            false,
             config_.style) {
     ET_LOG(
         Info,
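The surviving StaticAttentionMask::unmask path (see the @@ -341,20 +274,10 @@ hunk above) follows the same fixed-buffer idea: the mask has input_len rows of cache_len + input_len entries, and unmasking writes the "unmasked" value into the next update_len cache columns of every row. Below is a standalone sketch, with a plain float buffer standing in for the mask tensor and 0.0f assumed as the unmasked value (zero_val_ in the real class).

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of the SMART_MASK unmask: `mask` holds input_len rows, each with
// cache_len + input_len entries. Unmasking an update of length update_len
// clears the next update_len cache columns in every row.
void unmask_sketch(
    std::vector<float>& mask,
    size_t input_len,
    size_t cache_len,
    size_t& cache_valid_len,
    size_t update_len) {
  const float unmasked = 0.0f; // stands in for zero_val_
  for (size_t i = 0; i < input_len; i++) {
    float* p = mask.data() + (cache_len + input_len) * i;
    std::fill(p + cache_valid_len, p + cache_valid_len + update_len, unmasked);
  }
  cache_valid_len += update_len;
}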
