 
 namespace example {
 
+enum class StaticAttentionUpdateStyle {
+  /**
+   * KV caches will have valid data at the end of the cache. New elements are
+   * added at the end and the start of the cache will slide forward to maintain
+   * this invariant. This potentially allows shorter caches to be passed into
+   * the model by adjusting the start pointer.
+   */
+  SLIDING_CACHE,
+  /**
+   * I/O pointers do not change, which can enable persistent memory mapping
+   * between AP and NPU. Can also implement a circular cache by adjusting the
+   * attention mask accordingly.
+   */
+  SMART_MASK,
+};
+
 template <typename T, typename AllocatorT = std::allocator<T>>
 class StaticKVCache {
  public:
   /**
-   * Helper class to handle KV cache I/O. Assumes batch size 1, same context
-   * length and head dimension for each cache. Supports hybrid operation mixing
-   * prefill and decode. Create one instance for key caches and another one for
-   * value caches.
+   * Helper class to handle KV cache I/O. Assumes batch size 1, same length and
+   * head dimension for each cache. Supports multi-turn operation mixing prefill
+   * and decode by sharing the same cache between methods with different input
+   * lengths. Create one instance for key caches and another one for value
+   * caches.
    */
   StaticKVCache(
       size_t n_caches,
       size_t cache_len,
       size_t head_dim,
       size_t max_input_len = 1,
-      bool transpose = false)
+      bool transpose = false,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : n_caches_(n_caches),
         cache_len_(cache_len),
         max_input_len_(max_input_len),
         head_dim_(head_dim),
-        transpose_(transpose) {
-    // Updates are appeneded at the end. Need one extra segment to support the
-    // sliding window.
-    data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_ + max_input_len_;
-    data_ = allocator_.allocate(data_size_);
-    ET_CHECK(data_ != nullptr);
-    reset();
+        transpose_(transpose),
+        style_(style),
+        input_ptrs_(n_caches_),
+        output_ptrs_(n_caches_) {
+    if (transpose_) {
+      throw std::runtime_error("Not implemented.");
+    }
+
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      // Allocates one extra copy to accommodate caches sliding forward.
+      cache_data_size_ = (n_caches_ + 1) * cache_len_ * head_dim_;
+    } else {
+      cache_data_size_ = n_caches_ * cache_len_ * head_dim_;
+    }
+    update_data_size_ = n_caches_ * max_input_len_ * head_dim_;
+
+    cache_data_ = allocator_.allocate(cache_data_size_);
+    update_data_ = allocator_.allocate(update_data_size_);
+    ET_CHECK(cache_data_ != nullptr);
+    ET_CHECK(update_data_ != nullptr);
+    init_ptrs();
   }
 
   StaticKVCache(const StaticKVCache& other) = delete;
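
As a minimal construction sketch to accompany the hunk above (illustrative only, not part of the patch): one StaticKVCache instance holds all key caches and another all value caches, with hypothetical sizes and an explicitly chosen update style.

// ---- Usage sketch (illustrative, not part of the patch) ----
#include <cstddef>
// Assumes the header that defines example::StaticKVCache is included.

void make_caches() {
  using example::StaticAttentionUpdateStyle;
  constexpr std::size_t kNumCaches = 2;    // hypothetical: one cache per layer
  constexpr std::size_t kCacheLen = 128;   // hypothetical context budget
  constexpr std::size_t kHeadDim = 64;     // hypothetical head dimension
  constexpr std::size_t kMaxInputLen = 32; // hypothetical prefill chunk length

  // One instance for keys and one for values, as the class comment prescribes.
  example::StaticKVCache<float> k_caches(
      kNumCaches, kCacheLen, kHeadDim, kMaxInputLen, /*transpose=*/false,
      StaticAttentionUpdateStyle::SMART_MASK);
  example::StaticKVCache<float> v_caches(
      kNumCaches, kCacheLen, kHeadDim, kMaxInputLen, /*transpose=*/false,
      StaticAttentionUpdateStyle::SMART_MASK);
  // With the default SLIDING_CACHE style each instance would reserve one extra
  // cache_len * head_dim copy so the valid window can slide forward.
}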
@@ -50,23 +83,24 @@ class StaticKVCache {
   StaticKVCache& operator=(StaticKVCache&& other) = delete;
 
   ~StaticKVCache() {
-    allocator_.deallocate(data_, data_size_);
+    allocator_.deallocate(cache_data_, cache_data_size_);
+    allocator_.deallocate(update_data_, update_data_size_);
   }
 
   /**
    * Set up data pointers for the KV cache related inputs and outputs based on
    * the current state of the cache. Call StaticKVCache<T>::update or
-   * StaticKVCache<T>::reset first as needed before calling this function.
+   * StaticKVCache<T>::reset as needed before calling this function.
    */
   void prepare(
       torch::executor::Method& method,
       const std::vector<size_t>& inputIndices,
-      const std::vector<size_t>& outputIndices) {
-    ET_CHECK(inputIndices.size() == outputIndices.size());
+      const std::vector<size_t>& output_indices) {
+    ET_CHECK(inputIndices.size() == output_indices.size());
     auto methodMeta = method.method_meta();
     for (size_t i = 0; i < n_caches_; i++) {
       auto inIdx = inputIndices[i];
-      auto outIdx = outputIndices[i];
+      auto outIdx = output_indices[i];
       auto inMeta = methodMeta.input_tensor_meta(inIdx);
       auto outMeta = methodMeta.output_tensor_meta(outIdx);
       ET_CHECK(inMeta.ok());
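
A small sketch of the call ordering that the prepare() comment above requires, assuming a loaded ExecuTorch method whose cache input/output indices are already known; the index vectors and the helper name are hypothetical.

// ---- Usage sketch (illustrative, not part of the patch) ----
#include <cstddef>
#include <vector>
// Assumes the ExecuTorch Method and ET_CHECK headers used by this file.

void run_once(
    torch::executor::Method& method,
    example::StaticKVCache<float>& k_caches,
    const std::vector<std::size_t>& k_input_indices,
    const std::vector<std::size_t>& k_output_indices) {
  // Bring the cache to a known state first: reset() for a new sequence, or
  // update() right after the previous execution.
  k_caches.reset();
  // Then bind the method's cache inputs/outputs to the internal pointers.
  k_caches.prepare(method, k_input_indices, k_output_indices);
  ET_CHECK(method.execute() == torch::executor::Error::Ok);
}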
@@ -106,74 +140,90 @@ class StaticKVCache {
   /**
    * Update the internal data pointers using the cache updates returned by the
    * model. The length of each individual update cannot exceed the max update
-   * length specified during the creation, and the total length cannot exceed
-   * the context length.
+   * length specified during creation, and the total length cannot exceed the
+   * cache length.
    */
   void update(
       torch::executor::Method& method,
-      const std::vector<size_t>& outputIndices,
+      const std::vector<size_t>& output_indices,
       size_t update_len) {
     if (valid_len_ + update_len > cache_len_) {
       throw std::runtime_error("Cache capacity exceeded.");
     }
 
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      update_sliding_cache(method, output_indices, update_len);
     } else {
-      updateSeqDim(method, outputIndices, update_len);
+      update_smart_mask(method, output_indices, update_len);
     }
-    valid_len_ += update_len;
   }
 
   /**
    * Reset the cache. After this the cache contains no valid data and is ready
-   * for number of tokens up to the context length.
+   * for a number of tokens up to the cache length.
    */
   void reset() {
     valid_len_ = 0;
-    if (transpose_) {
-      throw std::runtime_error("Not implemented.");
-    } else {
-      initSeqDim();
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      init_ptrs();
     }
   }
 
  private:
-  void initSeqDim() {
-    auto cacheSize = cache_len_ * head_dim_;
+  void init_ptrs() {
     input_ptrs_.resize(n_caches_);
     output_ptrs_.resize(n_caches_);
     for (size_t i = 0; i < n_caches_; i++) {
-      input_ptrs_[i] = data_ + i * cacheSize;
-      output_ptrs_[i] = input_ptrs_[i] + cacheSize;
+      input_ptrs_[i] = cache_data_ + i * cache_len_ * head_dim_;
+      output_ptrs_[i] = update_data_ + i * max_input_len_ * head_dim_;
     }
   }
 
-  void updateSeqDim(
+  void update_sliding_cache(
       torch::executor::Method& method,
-      const std::vector<size_t>& outputIndices,
+      const std::vector<size_t>& output_indices,
       size_t update_len) {
-    ET_CHECK(n_caches_ == outputIndices.size());
+    ET_CHECK(n_caches_ == output_indices.size());
     for (size_t i = 0; i < n_caches_; i++) {
-      const auto& updateTensor = method.get_output(outputIndices[i]).toTensor();
-      ET_CHECK(
-          input_ptrs_[i] + cache_len_ * head_dim_ ==
-          updateTensor.mutable_data_ptr<T>());
-
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr<T>());
+      std::copy(
+          output_ptrs_[i],
+          output_ptrs_[i] + update_len * head_dim_,
+          input_ptrs_[i] + cache_len_ * head_dim_);
       input_ptrs_[i] += update_len * head_dim_;
-      output_ptrs_[i] += update_len * head_dim_;
     }
+    valid_len_ += update_len;
+  }
+
+  void update_smart_mask(
+      torch::executor::Method& method,
+      const std::vector<size_t>& output_indices,
+      size_t update_len) {
+    for (size_t i = 0; i < n_caches_; i++) {
+      const auto& updateTensor =
+          method.get_output(output_indices[i]).toTensor();
+      ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr<T>());
+      std::copy(
+          output_ptrs_[i],
+          output_ptrs_[i] + update_len * head_dim_,
+          input_ptrs_[i] + valid_len_ * head_dim_);
+    }
+    valid_len_ += update_len;
   }
 
-  // std::vector<T> pool_;
   size_t n_caches_;
   size_t cache_len_;
   size_t max_input_len_;
   size_t head_dim_;
   bool transpose_;
+  StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
-  size_t data_size_;
-  T* data_;
+  size_t cache_data_size_;
+  T* cache_data_;
+  size_t update_data_size_;
+  T* update_data_;
   std::vector<T*> input_ptrs_;
   std::vector<T*> output_ptrs_;
   size_t valid_len_ = 0;
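
To make the two update paths concrete, here is a standalone sketch (illustrative only) that mimics the copy destinations with plain buffers: under SLIDING_CACHE the update is appended after the cache and the window start advances, while under SMART_MASK the buffer stays put and the update lands at the current valid length.

// ---- Illustrative sketch (not part of the patch): where updates land ----
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr std::size_t cache_len = 4, head_dim = 1, update_len = 1;
  const float update = 7.f; // hypothetical new cache element

  // SLIDING_CACHE: copy to the tail of the over-allocated buffer, then slide
  // the window start forward so valid data still ends at the tail.
  std::vector<float> sliding((cache_len + update_len) * head_dim, 0.f);
  std::copy(&update, &update + update_len * head_dim,
            sliding.data() + cache_len * head_dim);
  const float* window_start = sliding.data() + update_len * head_dim;

  // SMART_MASK: the buffer and its I/O pointer never move; the update is
  // written at valid_len and the attention mask is widened instead.
  std::size_t valid_len = 0;
  std::vector<float> fixed(cache_len * head_dim, 0.f);
  std::copy(&update, &update + update_len * head_dim,
            fixed.data() + valid_len * head_dim);
  valid_len += update_len;

  std::printf("sliding tail: %.1f, fixed[0]: %.1f\n",
              window_start[cache_len * head_dim - 1], fixed[0]);
  return 0;
}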
@@ -183,28 +233,30 @@ template <typename T, typename AllocatorT = std::allocator<T>>
 class StaticAttentionMask {
  public:
   /**
-   * Manages the attention mask in the same style of KV cache IO where valid
-   * data is at the end of the cache. The mask has shape (1, maxSeqLen,
-   * cache_len
-   * + maxSeqLen) where maxSeqLen is 1 for decode or the prefill length. Accepts
-   * zero_val and mask_val (which represents -inf) to support quantized mask.
+   * Manages the attention mask for StaticKVCache. Create one mask for each
+   * input length. Accepts zero_val and mask_val (which represents -inf) to
+   * support quantized masks.
    *
-   * This class manages the slice of the mask at [:, :, : (cache_len -
-   * validCacheLen)]. User can update the rest of the mask to implement causal
-   * masking for example.
+   * The mask shape is (1, input_len, cache_len + input_len). This class manages
+   * the slice of the mask at [:, :, :cache_len] to only allow valid cache
+   * elements to participate in the attention. The user can update the rest of
+   * the mask (to implement a causal mask, for example).
    */
   StaticAttentionMask(
       size_t cache_len,
       size_t input_len,
       size_t head_dim,
       T zero_val,
-      T mask_val)
+      T mask_val,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : cache_len_(cache_len),
         input_len_(input_len),
         head_dim_(head_dim),
-        cache_mask_len_(cache_len_),
+        cache_valid_len_(0),
         zero_val_(zero_val),
-        mask_val_(mask_val) {
+        mask_val_(mask_val),
+        style_(style) {
     data_size_ = input_len_ * (cache_len_ + input_len_);
     data_ = allocator_.allocate(data_size_);
     ET_CHECK(data_ != nullptr);
@@ -224,7 +276,7 @@ class StaticAttentionMask {
    * Reset the mask to the state where the cache contains no valid data.
    */
   void reset() {
-    cache_mask_len_ = cache_len_;
+    cache_valid_len_ = 0;
     for (size_t i = 0; i < input_len_; i++) {
       auto* p = data_ + (cache_len_ + input_len_) * i;
       std::fill(p, p + cache_len_, mask_val_);
@@ -233,19 +285,29 @@ class StaticAttentionMask {
 
   /**
    * Update the mask to indicate update_len elements have been added to the
-   * cache. Note that update_len might be smaller than maxSeqLen when prefilling
-   * with padded inputs.
+   * cache. Note that update_len might be smaller than input_len_ when
+   * prefilling with padded inputs.
    */
-  void updateCacheMask(size_t update_len) {
-    for (size_t i = 0; i < input_len_; i++) {
-      auto* p = data_ + (cache_len_ + input_len_) * i;
-      std::fill(
-          p + cache_mask_len_ - update_len, p + cache_mask_len_, zero_val_);
+  void unmask(size_t update_len) {
+    if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) {
+      for (size_t i = 0; i < input_len_; i++) {
+        auto* p = data_ + (cache_len_ + input_len_) * i;
+        std::fill(
+            p + cache_len_ - cache_valid_len_ - update_len,
+            p + cache_len_ - cache_valid_len_,
+            zero_val_);
+      }
+    } else {
+      for (size_t i = 0; i < input_len_; i++) {
+        auto* p = data_ + (cache_len_ + input_len_) * i;
+        std::fill(
+            p + cache_valid_len_, p + cache_valid_len_ + update_len, zero_val_);
+      }
     }
-    cache_mask_len_ -= update_len;
+    cache_valid_len_ += update_len;
   }
 
-  void setCausalMask() {
+  void set_causal_mask() {
     for (size_t i = 0; i < input_len_ - 1; i++) {
       auto* p = data_ + (cache_len_ + input_len_) * i;
       std::fill(p + cache_len_, p + cache_len_ + 1 + i, zero_val_);
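
A short sketch of driving the mask for a padded prefill chunk (illustrative only); the sizes are hypothetical and a large negative float stands in for -inf as mask_val.

// ---- Usage sketch (illustrative, not part of the patch) ----
#include <cstddef>
// Assumes the header that defines example::StaticAttentionMask is included.

void drive_prefill_mask() {
  constexpr std::size_t kCacheLen = 128;  // hypothetical
  constexpr std::size_t kHeadDim = 64;    // hypothetical
  constexpr std::size_t kPrefillLen = 32; // hypothetical padded input length
  example::StaticAttentionMask<float> mask(
      kCacheLen, kPrefillLen, kHeadDim, /*zero_val=*/0.f, /*mask_val=*/-1e9f,
      example::StaticAttentionUpdateStyle::SMART_MASK);
  mask.reset();           // every cache column starts masked out
  mask.set_causal_mask(); // new tokens attend causally within the chunk
  // After a chunk of actual_len (<= kPrefillLen) tokens has been written to
  // the KV cache, let those cache columns participate in attention.
  const std::size_t actual_len = 20; // hypothetical unpadded token count
  mask.unmask(actual_len);
}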
@@ -261,9 +323,10 @@ class StaticAttentionMask {
   size_t cache_len_;
   size_t input_len_;
   size_t head_dim_;
-  size_t cache_mask_len_;
+  size_t cache_valid_len_;
   T zero_val_;
   T mask_val_;
+  StaticAttentionUpdateStyle style_;
   AllocatorT allocator_;
   size_t data_size_ = 0;
   T* data_;
@@ -285,7 +348,9 @@ class StaticAttentionIOManager {
       size_t rope_freqs_cos_index,
       size_t rope_freqs_sin_index,
       RopeT* rope_freqs_cos,
-      RopeT* rope_freqs_sin)
+      RopeT* rope_freqs_sin,
+      StaticAttentionUpdateStyle style =
+          StaticAttentionUpdateStyle::SLIDING_CACHE)
       : cache_len_(cache_len),
         head_dim_(head_dim),
         kCaches_(n_caches, cache_len, head_dim, max_input_len),
@@ -295,6 +360,9 @@ class StaticAttentionIOManager {
         rope_freqs_cos_(rope_freqs_cos),
         rope_freqs_sin_(rope_freqs_sin) {}
 
+  /**
+   * Create a new StaticAttentionMask that will be managed by this object.
+   */
   StaticAttentionMask<MaskT, MaskAllocatorT>&
   addMask(size_t input_len, MaskT zero_val, MaskT mask_val) {
     auto it = attentionMasks_.emplace(
@@ -305,10 +373,16 @@ class StaticAttentionIOManager {
     return it.first->second;
   }
 
+  /**
+   * Retrieve a mask suitable for the given input length.
+   */
   StaticAttentionMask<MaskT, MaskAllocatorT>& getMask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }
 
+  /**
+   * Set I/O pointers for KV cache and RoPE frequencies.
+   */
   void prepare(
       torch::executor::Method& method,
       const std::vector<size_t>& k_cache_input_indices,
@@ -327,6 +401,10 @@ class StaticAttentionIOManager {
         rope_freqs_sin_ + input_pos_ * head_dim_ / 2);
   }
 
+  /**
+   * Update all caches and masks under management to reflect that the model
+   * produced update_len new elements.
+   */
   void update(
       torch::executor::Method& method,
       const std::vector<size_t>& k_cache_output_indices,
@@ -336,10 +414,13 @@ class StaticAttentionIOManager {
     kCaches_.update(method, k_cache_output_indices, update_len);
     vCaches_.update(method, v_cache_output_indices, update_len);
     for (auto& it : attentionMasks_) {
-      it.second.updateCacheMask(update_len);
+      it.second.unmask(update_len);
     }
   }
 
+  /**
+   * Reset all caches and masks under management.
+   */
   void reset() {
     input_pos_ = 0;
     kCaches_.reset();
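
Finally, a sketch of a decode loop driven by StaticAttentionIOManager (illustrative only); the mask parameters, index vectors, and generation length are hypothetical, the manager is assumed to be instantiated with float masks, the prepare() call is left as a comment because its full argument list is not shown in this excerpt, and update()'s trailing parameters are inferred from the body above.

// ---- Usage sketch (illustrative, not part of the patch) ----
#include <cstddef>
#include <vector>
// Assumes the ExecuTorch Method and ET_CHECK headers used by this file.

template <typename ManagerT>
void decode_loop(
    ManagerT& mgr,
    torch::executor::Method& decode_method,
    const std::vector<std::size_t>& k_output_indices,
    const std::vector<std::size_t>& v_output_indices) {
  // One mask per input length; decode uses input_len == 1.
  mgr.addMask(/*input_len=*/1, /*zero_val=*/0.f, /*mask_val=*/-1e9f);
  for (int step = 0; step < 16; ++step) { // hypothetical generation length
    // mgr.prepare(decode_method, ...);   // binds cache and RoPE I/O pointers
    ET_CHECK(decode_method.execute() == torch::executor::Error::Ok);
    // Appends one new element to every cache and unmasks it in every mask.
    mgr.update(decode_method, k_output_indices, v_output_indices,
               /*update_len=*/1);
  }
  mgr.reset(); // start a fresh sequence
}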