
Commit a70c6a3

Static attention IO manager
Differential Revision: D69624077 Pull Request resolved: #8486
1 parent 159b932 commit a70c6a3

File tree: 2 files changed, +376 −0 lines changed
examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 363 additions & 0 deletions
@@ -0,0 +1,363 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <memory>
#include <tuple>
#include <unordered_map>
#include <vector>

#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/method.h>

namespace example {

template <typename T, typename AllocatorT = std::allocator<T>>
class StaticKVCache {
 public:
  /**
   * Helper class to handle KV cache I/O. Assumes batch size 1, same context
   * length and head dimension for each cache. Supports hybrid operation mixing
   * prefill and decode. Create one instance for key caches and another one for
   * value caches.
   */
  StaticKVCache(
      size_t n_caches,
      size_t cache_len,
      size_t head_dim,
      size_t max_input_len = 1,
      bool transpose = false)
      : n_caches_(n_caches),
        cache_len_(cache_len),
        max_input_len_(max_input_len),
        head_dim_(head_dim),
        transpose_(transpose) {
    // Updates are appended at the end. Need one extra segment to support the
    // sliding window.
    data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_ + max_input_len_;
    data_ = allocator_.allocate(data_size_);
    ET_CHECK(data_ != nullptr);
    reset();
  }

  ~StaticKVCache() {
    allocator_.deallocate(data_, data_size_);
  }

  /**
   * Set up data pointers for the KV cache related inputs and outputs based on
   * the current state of the cache. Call StaticKVCache<T>::update or
   * StaticKVCache<T>::reset first as needed before calling this function.
   */
  void prepare(
      torch::executor::Method& method,
      const std::vector<size_t>& inputIndices,
      const std::vector<size_t>& outputIndices) {
    ET_CHECK(inputIndices.size() == outputIndices.size());
    auto methodMeta = method.method_meta();
    for (size_t i = 0; i < n_caches_; i++) {
      auto inIdx = inputIndices[i];
      auto outIdx = outputIndices[i];
      auto inMeta = methodMeta.input_tensor_meta(inIdx);
      auto outMeta = methodMeta.output_tensor_meta(outIdx);
      ET_CHECK(inMeta.ok());
      ET_CHECK(outMeta.ok());

      auto inSizes = inMeta->sizes();
      auto outSizes = outMeta->sizes();
      ET_CHECK_MSG(inSizes[0] == 1, "Only support batch size 1.");
      ET_CHECK_MSG(outSizes[0] == 1, "Only support batch size 1.");
      if (transpose_) {
        ET_CHECK_MSG(inSizes[1] == head_dim_, "KV head dim mismatch.");
        ET_CHECK_MSG(outSizes[1] == head_dim_, "KV head dim mismatch.");
        ET_CHECK_MSG(inSizes[2] == cache_len_, "Cache length dim mismatch.");
      } else {
        ET_CHECK_MSG(inSizes[2] == head_dim_, "KV head dim mismatch.");
        ET_CHECK_MSG(outSizes[2] == head_dim_, "KV head dim mismatch.");
        ET_CHECK_MSG(inSizes[1] == cache_len_, "Cache length dim mismatch.");
      }

      auto impl = ::executorch::runtime::etensor::TensorImpl(
          inMeta->scalar_type(),
          inMeta->sizes().size(),
          const_cast<torch::executor::TensorImpl::SizesType*>(
              inMeta->sizes().data()),
          input_ptrs_[i],
          const_cast<torch::executor::TensorImpl::DimOrderType*>(
              inMeta->dim_order().data()));
      torch::executor::Tensor t(&impl);
      ET_CHECK(method.set_input(t, inIdx) == torch::executor::Error::Ok);
      ET_CHECK(
          method.set_output_data_ptr(
              output_ptrs_[i], outMeta->nbytes(), outIdx) ==
          torch::executor::Error::Ok);
    }
  }

  /**
   * Update the internal data pointers using the cache updates returned by the
   * model. The length of each individual update cannot exceed the maximum
   * input length specified at construction, and the total length cannot
   * exceed the context length.
   */
  void update(
      torch::executor::Method& method,
      const std::vector<size_t>& outputIndices,
      size_t update_len) {
    if (valid_len_ + update_len > cache_len_) {
      throw std::runtime_error("Cache capacity exceeded.");
    }

    if (transpose_) {
      throw std::runtime_error("Not implemented.");
    } else {
      updateSeqDim(method, outputIndices, update_len);
    }
    valid_len_ += update_len;
  }

  /**
   * Reset the cache. After this the cache contains no valid data and is ready
   * to accept up to the context length's worth of tokens.
   */
  void reset() {
    valid_len_ = 0;
    if (transpose_) {
      throw std::runtime_error("Not implemented.");
    } else {
      initSeqDim();
    }
  }

 private:
  void initSeqDim() {
    auto cacheSize = cache_len_ * head_dim_;
    input_ptrs_.resize(n_caches_);
    output_ptrs_.resize(n_caches_);
    for (size_t i = 0; i < n_caches_; i++) {
      input_ptrs_[i] = data_ + i * cacheSize;
      output_ptrs_[i] = input_ptrs_[i] + cacheSize;
    }
  }

  void updateSeqDim(
      torch::executor::Method& method,
      const std::vector<size_t>& outputIndices,
      size_t update_len) {
    ET_CHECK(n_caches_ == outputIndices.size());
    for (size_t i = 0; i < n_caches_; i++) {
      const auto& updateTensor = method.get_output(outputIndices[i]).toTensor();
      ET_CHECK(
          input_ptrs_[i] + cache_len_ * head_dim_ ==
          updateTensor.mutable_data_ptr<T>());

      input_ptrs_[i] += update_len * head_dim_;
      output_ptrs_[i] += update_len * head_dim_;
    }
  }
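
  // Worked example (illustrative, with assumed numbers): with one cache,
  // cache_len_ = 4, and head_dim_ = 2, initSeqDim() leaves input_ptrs_[0] at
  // offset 0 and output_ptrs_[0] at offset 8 (one full cache past it). The
  // model writes a length-1 update at offsets 8..9; updateSeqDim(..., 1) then
  // advances both pointers by head_dim_ = 2, so the next invocation reads the
  // window at offsets 2..9, which now ends with the new entry, and writes its
  // next update starting at offset 10. The window slides forward without
  // copying any cache data.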

  // std::vector<T> pool_;
  size_t n_caches_;
  size_t cache_len_;
  size_t max_input_len_;
  size_t head_dim_;
  bool transpose_;
  AllocatorT allocator_;
  size_t data_size_;
  T* data_;
  std::vector<T*> input_ptrs_;
  std::vector<T*> output_ptrs_;
  size_t valid_len_ = 0;
};

template <typename T, typename AllocatorT = std::allocator<T>>
class StaticAttentionMask {
 public:
  /**
   * Manages the attention mask in the same style as the KV cache I/O, where
   * valid data sits at the end of the cache. The mask has shape
   * (1, input_len, cache_len + input_len), where input_len is 1 for decode or
   * the prefill length. Accepts zero_val and mask_val (the latter representing
   * -inf) so that quantized masks are supported.
   *
   * This class manages the slice of the mask at
   * [:, :, : (cache_len - validCacheLen)]. The user can update the rest of the
   * mask, for example to implement causal masking.
   */
  StaticAttentionMask(
      size_t cache_len,
      size_t input_len,
      size_t head_dim,
      T zero_val,
      T mask_val)
      : cache_len_(cache_len),
        input_len_(input_len),
        head_dim_(head_dim),
        cache_mask_len_(cache_len_),
        zero_val_(zero_val),
        mask_val_(mask_val) {
    data_size_ = input_len_ * (cache_len_ + input_len_);
    data_ = allocator_.allocate(data_size_);
    ET_CHECK(data_ != nullptr);
    reset();
  }

  /**
   * Reset the mask to the state where the cache contains no valid data.
   */
  void reset() {
    cache_mask_len_ = cache_len_;
    for (size_t i = 0; i < input_len_; i++) {
      auto* p = data_ + (cache_len_ + input_len_) * i;
      std::fill(p, p + cache_len_, mask_val_);
    }
  }

  /**
   * Update the mask to indicate that update_len elements have been added to
   * the cache. Note that update_len might be smaller than input_len when
   * prefilling with padded inputs.
   */
  void updateCacheMask(size_t update_len) {
    for (size_t i = 0; i < input_len_; i++) {
      auto* p = data_ + (cache_len_ + input_len_) * i;
      std::fill(
          p + cache_mask_len_ - update_len, p + cache_mask_len_, zero_val_);
    }
    cache_mask_len_ -= update_len;
  }
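
  // Illustrative note: each of the input_len_ rows spans cache_len_ cache
  // positions followed by input_len_ input positions. reset() masks the whole
  // cache region of every row; each updateCacheMask(update_len) call then
  // unmasks update_len more positions, growing right to left so the unmasked
  // region always covers the most recent cache entries, matching how
  // StaticKVCache keeps its valid data at the end of the window.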

  void setCausalMask() {
    for (size_t i = 0; i < input_len_; i++) {
      auto* p = data_ + (cache_len_ + input_len_) * i;
      std::fill(p + cache_len_, p + cache_len_ + 1 + i, zero_val_);
      std::fill(p + cache_len_ + 1 + i, p + cache_len_ + input_len_, mask_val_);
    }
  }

  T* get() {
    return data_;
  }

 private:
  size_t cache_len_;
  size_t input_len_;
  size_t head_dim_;
  size_t cache_mask_len_;
  T zero_val_;
  T mask_val_;
  AllocatorT allocator_;
  size_t data_size_ = 0;
  T* data_;
};

template <
    typename CacheT,
    typename MaskT,
    typename RopeT,
    typename CacheAllocatorT = std::allocator<CacheT>,
    typename MaskAllocatorT = std::allocator<MaskT>>
class StaticAttentionIOManager {
 public:
  StaticAttentionIOManager(
      size_t n_caches,
      size_t cache_len,
      size_t head_dim,
      size_t max_input_len,
      size_t rope_freqs_cos_index,
      size_t rope_freqs_sin_index,
      RopeT* rope_freqs_cos,
      RopeT* rope_freqs_sin)
      : cache_len_(cache_len),
        head_dim_(head_dim),
        kCaches_(n_caches, cache_len, head_dim, max_input_len),
        vCaches_(n_caches, cache_len, head_dim, max_input_len),
        rope_freqs_cos_index_(rope_freqs_cos_index),
        rope_freqs_sin_index_(rope_freqs_sin_index),
        rope_freqs_cos_(rope_freqs_cos),
        rope_freqs_sin_(rope_freqs_sin) {}

  StaticAttentionMask<MaskT, MaskAllocatorT>&
  addMask(size_t input_len, MaskT zero_val, MaskT mask_val) {
    auto it = attentionMasks_.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(input_len),
        std::forward_as_tuple(
            cache_len_, input_len, head_dim_, zero_val, mask_val));
    return it.first->second;
  }

  StaticAttentionMask<MaskT, MaskAllocatorT>& getMask(size_t input_len) {
    return attentionMasks_.at(input_len);
  }

  void prepare(
      torch::executor::Method& method,
      const std::vector<size_t>& k_cache_input_indices,
      const std::vector<size_t>& k_cache_output_indices,
      const std::vector<size_t>& v_cache_input_indices,
      const std::vector<size_t>& v_cache_output_indices) {
    kCaches_.prepare(method, k_cache_input_indices, k_cache_output_indices);
    vCaches_.prepare(method, v_cache_input_indices, v_cache_output_indices);
    set_input(
        method,
        rope_freqs_cos_index_,
        rope_freqs_cos_ + input_pos_ * head_dim_ / 2);
    set_input(
        method,
        rope_freqs_sin_index_,
        rope_freqs_sin_ + input_pos_ * head_dim_ / 2);
  }

  void update(
      torch::executor::Method& method,
      const std::vector<size_t>& k_cache_output_indices,
      const std::vector<size_t>& v_cache_output_indices,
      size_t update_len) {
    input_pos_ += update_len;
    kCaches_.update(method, k_cache_output_indices, update_len);
    vCaches_.update(method, v_cache_output_indices, update_len);
    for (auto& it : attentionMasks_) {
      it.second.updateCacheMask(update_len);
    }
  }

  void reset() {
    input_pos_ = 0;
    kCaches_.reset();
    vCaches_.reset();
    for (auto& it : attentionMasks_) {
      it.second.reset();
    }
  }

 private:
  template <typename T>
  void set_input(executorch::runtime::Method& method, size_t idx, T* data) {
    auto methodMeta = method.method_meta();
    auto inputMeta = methodMeta.input_tensor_meta(idx);
    auto impl = ::executorch::runtime::etensor::TensorImpl(
        inputMeta->scalar_type(),
        inputMeta->sizes().size(),
        const_cast<executorch::aten::TensorImpl::SizesType*>(
            inputMeta->sizes().data()),
        data,
        const_cast<executorch::aten::TensorImpl::DimOrderType*>(
            inputMeta->dim_order().data()));
    executorch::runtime::etensor::Tensor t(&impl);
    ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok);
  }

  size_t cache_len_;
  size_t input_len_;
  size_t head_dim_;
  size_t input_pos_ = 0;
  StaticKVCache<CacheT, CacheAllocatorT> kCaches_;
  StaticKVCache<CacheT, CacheAllocatorT> vCaches_;
  std::unordered_map<size_t, StaticAttentionMask<MaskT, MaskAllocatorT>>
      attentionMasks_;
  size_t rope_freqs_cos_index_;
  size_t rope_freqs_sin_index_;
  RopeT* rope_freqs_cos_;
  RopeT* rope_freqs_sin_;
};

} // namespace example
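
To see how the pieces fit together, here is a minimal decode-loop sketch built on StaticAttentionIOManager. It is illustrative only: the model dimensions, input/output argument indices, and RoPE frequency tables below are placeholder assumptions, and binding the token ids and the mask buffer to the method's inputs is model-specific and omitted.

// Illustrative sketch; the indices and dimensions are assumptions, not part
// of this commit.
#include <limits>
#include <vector>

#include <executorch/runtime/executor/method.h>

#include "static_attention_io_manager.h"

void decode_sketch(
    torch::executor::Method& method,
    float* rope_cos, // precomputed tables, head_dim / 2 entries per position
    float* rope_sin) {
  constexpr size_t kNumLayers = 4; // hypothetical model dimensions
  constexpr size_t kCacheLen = 128;
  constexpr size_t kHeadDim = 64;
  // Hypothetical argument positions of the exported method.
  std::vector<size_t> k_in = {2, 3, 4, 5}, k_out = {1, 2, 3, 4};
  std::vector<size_t> v_in = {6, 7, 8, 9}, v_out = {5, 6, 7, 8};

  example::StaticAttentionIOManager<float, float, float> mgr(
      kNumLayers,
      kCacheLen,
      kHeadDim,
      /*max_input_len=*/1,
      /*rope_freqs_cos_index=*/10,
      /*rope_freqs_sin_index=*/11,
      rope_cos,
      rope_sin);

  // Decode uses input length 1; 0.0f means "attend", -inf means "masked".
  auto& mask = mgr.addMask(1, 0.0f, -std::numeric_limits<float>::infinity());
  mask.setCausalMask();
  // Feeding mask.get() and the current token to the method is omitted here.

  for (size_t step = 0; step < 16; step++) {
    // Point the KV cache and RoPE inputs/outputs at the current window.
    mgr.prepare(method, k_in, k_out, v_in, v_out);
    ET_CHECK(method.execute() == torch::executor::Error::Ok);
    // Slide the caches forward and unmask the newly written position.
    mgr.update(method, k_out, v_out, /*update_len=*/1);
  }
}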

examples/models/llama/runner/targets.bzl

Lines changed: 13 additions & 0 deletions
@@ -61,3 +61,16 @@ def define_common_targets():
            "libtorch",
        ] if aten else [],
    )

    runtime.cxx_library(
        name = "static_attention_io_manager",
        exported_headers = [
            "static_attention_io_manager.h",
        ],
        visibility = [
            "@EXECUTORCH_CLIENTS",
        ],
        exported_deps = [
            "//executorch/runtime/executor:program",
        ],
    )
