@@ -168,6 +168,10 @@ class StaticKVCache {
     }
   }
 
+  size_t size() {
+    return valid_len_;
+  }
+
  private:
   void init_ptrs() {
     input_ptrs_.resize(n_caches_);
@@ -339,26 +343,40 @@ template <
     typename MaskAllocatorT = std::allocator<MaskT>>
 class StaticAttentionIOManager {
  public:
-  StaticAttentionIOManager(
-      size_t n_caches,
-      size_t cache_len,
-      size_t head_dim,
-      size_t max_input_len,
-      size_t rope_freqs_cos_index,
-      size_t rope_freqs_sin_index,
-      RopeT* rope_freqs_cos,
-      RopeT* rope_freqs_sin,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
-      : cache_len_(cache_len),
-        head_dim_(head_dim),
-        style_(style),
-        kCaches_(n_caches, cache_len, head_dim, max_input_len, false, style),
-        vCaches_(n_caches, cache_len, head_dim, max_input_len, false, style),
-        rope_freqs_cos_index_(rope_freqs_cos_index),
-        rope_freqs_sin_index_(rope_freqs_sin_index),
-        rope_freqs_cos_(rope_freqs_cos),
-        rope_freqs_sin_(rope_freqs_sin) {}
+  struct StaticAttentionIOConfig {
+    size_t n_caches{};
+    size_t cache_len{};
+    size_t head_dim{};
+    size_t max_input_len{};
+    size_t attn_mask_input_index{};
+    size_t rope_freqs_cos_input_index{};
+    size_t rope_freqs_sin_input_index{};
+    std::vector<size_t> k_cache_input_indices;
+    std::vector<size_t> k_cache_output_indices;
+    std::vector<size_t> v_cache_input_indices;
+    std::vector<size_t> v_cache_output_indices;
+    RopeT* rope_freqs_cos;
+    RopeT* rope_freqs_sin;
+    StaticAttentionUpdateStyle style =
+        StaticAttentionUpdateStyle::SLIDING_CACHE;
+  };
+
+  StaticAttentionIOManager(StaticAttentionIOConfig config)
+      : config_(std::move(config)),
+        kCaches_(
+            config_.n_caches,
+            config_.cache_len,
+            config_.head_dim,
+            config_.max_input_len,
+            false,
+            config_.style),
+        vCaches_(
+            config_.n_caches,
+            config_.cache_len,
+            config_.head_dim,
+            config_.max_input_len,
+            false,
+            config_.style) {}
 
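Call sites now populate a named config struct instead of a long positional argument list. A hedged sketch of constructing the manager under the new API; all index values, pointer names, and the template arguments (CacheT, RopeT, MaskT as float) are illustrative placeholders, not taken from the patch:

// Sketch only: I/O slot indices are model-specific placeholders.
using IOManager = example::StaticAttentionIOManager<float, float, float>;

IOManager::StaticAttentionIOConfig config;
config.n_caches = 32;  // one K and one V cache per layer
config.cache_len = 1024;
config.head_dim = 64;
config.max_input_len = 128;
config.attn_mask_input_index = 1;
config.rope_freqs_cos_input_index = 2;
config.rope_freqs_sin_input_index = 3;
config.k_cache_input_indices = {4, 5, 6 /* ... */};
config.k_cache_output_indices = {1, 2, 3 /* ... */};
config.v_cache_input_indices = {7, 8, 9 /* ... */};
config.v_cache_output_indices = {4, 5, 6 /* ... */};
config.rope_freqs_cos = precomputed_cos.data();  // see the RoPE table sketch below
config.rope_freqs_sin = precomputed_sin.data();
IOManager io_mgr(std::move(config));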
   /**
    * Create a new StaticAttentionMask that will be managed by this object.
@@ -369,36 +387,38 @@ class StaticAttentionIOManager {
         std::piecewise_construct,
         std::forward_as_tuple(input_len),
         std::forward_as_tuple(
-            cache_len_, input_len, head_dim_, zero_val, mask_val, style_));
+            config_.cache_len,
+            input_len,
+            config_.head_dim,
+            zero_val,
+            mask_val,
+            config_.style));
     return it.first->second;
   }
 
   /**
    * Retrieve a mask suitable for given input length.
    */
-  StaticAttentionMask<MaskT, MaskAllocatorT>& getMask(size_t input_len) {
+  StaticAttentionMask<MaskT, MaskAllocatorT>& get_mask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }
 
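One mask is kept per distinct input length (e.g., one for prefill and one for single-token decode). The add_mask signature is only partly visible in this hunk; assuming it is add_mask(input_len, zero_val, mask_val), matching the forward_as_tuple arguments above, usage might look like this sketch:

#include <limits>

// Hedged sketch: zero_val marks attendable positions, mask_val masked ones.
const float neg_inf = -std::numeric_limits<float>::infinity();
auto& prefill_mask = io_mgr.add_mask(/*input_len=*/128, /*zero_val=*/0.0f, /*mask_val=*/neg_inf);
auto& decode_mask = io_mgr.add_mask(/*input_len=*/1, /*zero_val=*/0.0f, /*mask_val=*/neg_inf);

// Later, retrieve by the same input length:
auto& mask = io_mgr.get_mask(1);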
   /**
    * Set I/O pointers for KV cache and RoPE frequencies.
    */
-  void prepare(
-      torch::executor::Method& method,
-      const std::vector<size_t>& k_cache_input_indices,
-      const std::vector<size_t>& k_cache_output_indices,
-      const std::vector<size_t>& v_cache_input_indices,
-      const std::vector<size_t>& v_cache_output_indices) {
-    kCaches_.prepare(method, k_cache_input_indices, k_cache_output_indices);
-    vCaches_.prepare(method, v_cache_input_indices, v_cache_output_indices);
+  void prepare(torch::executor::Method& method) {
+    kCaches_.prepare(
+        method, config_.k_cache_input_indices, config_.k_cache_output_indices);
+    vCaches_.prepare(
+        method, config_.v_cache_input_indices, config_.v_cache_output_indices);
     set_input(
         method,
-        rope_freqs_cos_index_,
-        rope_freqs_cos_ + input_pos_ * head_dim_ / 2);
+        config_.rope_freqs_cos_input_index,
+        config_.rope_freqs_cos + input_pos_ * config_.head_dim / 2);
     set_input(
         method,
-        rope_freqs_sin_index_,
-        rope_freqs_sin_ + input_pos_ * head_dim_ / 2);
+        config_.rope_freqs_sin_input_index,
+        config_.rope_freqs_sin + input_pos_ * config_.head_dim / 2);
   }
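prepare() advances the RoPE frequency pointers by input_pos_ * head_dim / 2, which presumes row-major tables of shape [max_seq_len, head_dim / 2]: each position owns head_dim / 2 cos values and head_dim / 2 sin values. A hedged sketch of building a compatible cos table (base 10000 is the common RoPE default, not something this patch specifies):

#include <cmath>
#include <vector>

// Hedged sketch: builds a [max_seq_len, head_dim / 2] cos table matching the
// row-major layout the pointer arithmetic above assumes.
std::vector<float> build_rope_cos_table(size_t max_seq_len, size_t head_dim) {
  std::vector<float> table(max_seq_len * head_dim / 2);
  for (size_t p = 0; p < max_seq_len; ++p) {
    for (size_t i = 0; i < head_dim / 2; ++i) {
      float freq = std::pow(10000.0f, -2.0f * static_cast<float>(i) / head_dim);
      table[p * head_dim / 2 + i] = std::cos(p * freq);
    }
  }
  return table;
}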
 
   /**
@@ -430,6 +450,36 @@ class StaticAttentionIOManager {
     }
   }
 
+  template <typename TokenT>
+  std::vector<TokenT> decode(
+      TokenT prev_tok,
+      executorch::runtime::Span<TokenT> input_buffer,
+      executorch::runtime::Method& method,
+      std::function<TokenT(executorch::runtime::Method&)>& sample,
+      std::function<bool(TokenT)>& should_stop) {
+    set_input(method, 0, input_buffer.data());
+    auto& mask = get_mask(input_buffer.size());
+    set_input(method, config_.attn_mask_input_index, mask.get());
+
+    std::vector<TokenT> generated_tokens;
+    while (kCaches_.size() + 1 <= config_.cache_len) {
+      input_buffer[0] = prev_tok;
+      prepare(method);
+      ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
+      update(
+          method,
+          config_.k_cache_output_indices,
+          config_.v_cache_output_indices,
+          1);
+      prev_tok = sample(method);
+      generated_tokens.emplace_back(prev_tok);
+      if (should_stop(prev_tok)) {
+        break;
+      }
+    }
+    return generated_tokens;
+  }
+
  private:
   template <typename T>
   void set_input(executorch::runtime::Method& method, size_t idx, T* data) {
@@ -447,19 +497,12 @@ class StaticAttentionIOManager {
     ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok);
   }
 
-  size_t cache_len_;
-  size_t input_len_;
-  size_t head_dim_;
+  StaticAttentionIOConfig config_;
   size_t input_pos_;
-  StaticAttentionUpdateStyle style_;
   StaticKVCache<CacheT, CacheAllocatorT> kCaches_;
   StaticKVCache<CacheT, CacheAllocatorT> vCaches_;
   std::unordered_map<size_t, StaticAttentionMask<MaskT, MaskAllocatorT>>
       attentionMasks_;
-  size_t rope_freqs_cos_index_;
-  size_t rope_freqs_sin_index_;
-  RopeT* rope_freqs_cos_;
-  RopeT* rope_freqs_sin_;
 };
 
 } // namespace example