diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h
index c1d6ad9bb07..6a7559bc5bd 100644
--- a/examples/models/llama/runner/static_attention_io_manager.h
+++ b/examples/models/llama/runner/static_attention_io_manager.h
@@ -168,6 +168,10 @@ class StaticKVCache {
     }
   }
 
+  size_t size() {
+    return valid_len_;
+  }
+
  private:
   void init_ptrs() {
     input_ptrs_.resize(n_caches_);
@@ -339,26 +343,40 @@ template <
     typename MaskAllocatorT = std::allocator<MaskT>>
 class StaticAttentionIOManager {
  public:
-  StaticAttentionIOManager(
-      size_t n_caches,
-      size_t cache_len,
-      size_t head_dim,
-      size_t max_input_len,
-      size_t rope_freqs_cos_index,
-      size_t rope_freqs_sin_index,
-      RopeT* rope_freqs_cos,
-      RopeT* rope_freqs_sin,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
-      : cache_len_(cache_len),
-        head_dim_(head_dim),
-        style_(style),
-        kCaches_(n_caches, cache_len, head_dim, max_input_len, false, style),
-        vCaches_(n_caches, cache_len, head_dim, max_input_len, false, style),
-        rope_freqs_cos_index_(rope_freqs_cos_index),
-        rope_freqs_sin_index_(rope_freqs_sin_index),
-        rope_freqs_cos_(rope_freqs_cos),
-        rope_freqs_sin_(rope_freqs_sin) {}
+  struct StaticAttentionIOConfig {
+    size_t n_caches{};
+    size_t cache_len{};
+    size_t head_dim{};
+    size_t max_input_len{};
+    size_t attn_mask_input_index{};
+    size_t rope_freqs_cos_input_index{};
+    size_t rope_freqs_sin_input_index{};
+    std::vector<size_t> k_cache_input_indices;
+    std::vector<size_t> k_cache_output_indices;
+    std::vector<size_t> v_cache_input_indices;
+    std::vector<size_t> v_cache_output_indices;
+    RopeT* rope_freqs_cos;
+    RopeT* rope_freqs_sin;
+    StaticAttentionUpdateStyle style =
+        StaticAttentionUpdateStyle::SLIDING_CACHE;
+  };
+
+  StaticAttentionIOManager(StaticAttentionIOConfig config)
+      : config_(std::move(config)),
+        kCaches_(
+            config_.n_caches,
+            config_.cache_len,
+            config_.head_dim,
+            config_.max_input_len,
+            false,
+            config_.style),
+        vCaches_(
+            config_.n_caches,
+            config_.cache_len,
+            config_.head_dim,
+            config_.max_input_len,
+            false,
+            config_.style) {}
 
   /**
    * Create a new StaticAttentionMask that will be managed by this object.
@@ -369,36 +387,38 @@ class StaticAttentionIOManager {
         std::piecewise_construct,
         std::forward_as_tuple(input_len),
         std::forward_as_tuple(
-            cache_len_, input_len, head_dim_, zero_val, mask_val, style_));
+            config_.cache_len,
+            input_len,
+            config_.head_dim,
+            zero_val,
+            mask_val,
+            config_.style));
     return it.first->second;
   }
 
   /**
    * Retrieve a mask suitable for given input length.
    */
-  StaticAttentionMask<MaskT, MaskAllocatorT>& getMask(size_t input_len) {
+  StaticAttentionMask<MaskT, MaskAllocatorT>& get_mask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }
 
   /**
    * Set I/O pointers for KV cache and RoPE freqencies.
   */
-  void prepare(
-      torch::executor::Method& method,
-      const std::vector<size_t>& k_cache_input_indices,
-      const std::vector<size_t>& k_cache_output_indices,
-      const std::vector<size_t>& v_cache_input_indices,
-      const std::vector<size_t>& v_cache_output_indices) {
-    kCaches_.prepare(method, k_cache_input_indices, k_cache_output_indices);
-    vCaches_.prepare(method, v_cache_input_indices, v_cache_output_indices);
+  void prepare(torch::executor::Method& method) {
+    kCaches_.prepare(
+        method, config_.k_cache_input_indices, config_.k_cache_output_indices);
+    vCaches_.prepare(
+        method, config_.v_cache_input_indices, config_.v_cache_output_indices);
     set_input(
         method,
-        rope_freqs_cos_index_,
-        rope_freqs_cos_ + input_pos_ * head_dim_ / 2);
+        config_.rope_freqs_cos_input_index,
+        config_.rope_freqs_cos + input_pos_ * config_.head_dim / 2);
     set_input(
         method,
-        rope_freqs_sin_index_,
-        rope_freqs_sin_ + input_pos_ * head_dim_ / 2);
+        config_.rope_freqs_sin_input_index,
+        config_.rope_freqs_sin + input_pos_ * config_.head_dim / 2);
   }
 
   /**
@@ -430,6 +450,36 @@ class StaticAttentionIOManager {
     }
   }
 
+  template <typename TokenT>
+  std::vector<TokenT> decode(
+      TokenT prev_tok,
+      executorch::runtime::Span<TokenT> input_buffer,
+      executorch::runtime::Method& method,
+      std::function<TokenT(executorch::runtime::Method&)>& sample,
+      std::function<bool(TokenT)>& should_stop) {
+    set_input(method, 0, input_buffer.data());
+    auto& mask = get_mask(input_buffer.size());
+    set_input(method, config_.attn_mask_input_index, mask.get());
+
+    std::vector<TokenT> generated_tokens;
+    while (kCaches_.size() + 1 <= config_.cache_len) {
+      input_buffer[0] = prev_tok;
+      prepare(method);
+      ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
+      update(
+          method,
+          config_.k_cache_output_indices,
+          config_.v_cache_output_indices,
+          1);
+      prev_tok = sample(method);
+      generated_tokens.emplace_back(prev_tok);
+      if (should_stop(prev_tok)) {
+        break;
+      }
+    }
+    return generated_tokens;
+  }
+
  private:
   template <typename T>
   void set_input(executorch::runtime::Method& method, size_t idx, T* data) {
@@ -447,19 +497,12 @@ class StaticAttentionIOManager {
     ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok);
   }
 
-  size_t cache_len_;
-  size_t input_len_;
-  size_t head_dim_;
+  StaticAttentionIOConfig config_;
   size_t input_pos_;
-  StaticAttentionUpdateStyle style_;
   StaticKVCache<CacheT, CacheAllocatorT> kCaches_;
   StaticKVCache<CacheT, CacheAllocatorT> vCaches_;
   std::unordered_map<size_t, StaticAttentionMask<MaskT, MaskAllocatorT>>
       attentionMasks_;
-  size_t rope_freqs_cos_index_;
-  size_t rope_freqs_sin_index_;
-  RopeT* rope_freqs_cos_;
-  RopeT* rope_freqs_sin_;
 };
 
 } // namespace example
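
Usage sketch for reviewers, not part of the patch. It shows how a runner might fill the new StaticAttentionIOConfig once and then drive prepare()/decode() without passing indices at every call. All index values are placeholders that depend on how the model was exported; the <CacheT, MaskT, RopeT> instantiation and the mask-creation helper name (add_mask) are assumptions, since the full template parameter list and that function's signature fall outside the hunks above.

#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

#include <executorch/examples/models/llama/runner/static_attention_io_manager.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/method.h>

// Assumed instantiation: float KV caches, float mask, float RoPE tables.
using IOManager = example::StaticAttentionIOManager<float, float, float>;

std::vector<int64_t> generate(
    executorch::runtime::Method& method,
    float* rope_cos_table, // precomputed RoPE tables owned by the caller
    float* rope_sin_table,
    int64_t last_prefill_token,
    int64_t eos_token) {
  IOManager::StaticAttentionIOConfig cfg;
  cfg.n_caches = 32; // placeholder: number of K (or V) caches in the method
  cfg.cache_len = 1024; // placeholder
  cfg.head_dim = 64; // placeholder
  cfg.max_input_len = 128; // placeholder
  cfg.attn_mask_input_index = 1; // placeholders: depend on export order
  cfg.rope_freqs_cos_input_index = 2;
  cfg.rope_freqs_sin_input_index = 3;
  cfg.k_cache_input_indices = {4, 5}; // placeholders
  cfg.k_cache_output_indices = {1, 2};
  cfg.v_cache_input_indices = {6, 7};
  cfg.v_cache_output_indices = {3, 4};
  cfg.rope_freqs_cos = rope_cos_table;
  cfg.rope_freqs_sin = rope_sin_table;

  IOManager io_mgr(std::move(cfg));

  // decode() looks up a mask via get_mask(), so one must already exist for
  // this input length; add_mask is the assumed name of the creation helper
  // whose doc comment appears in the diff.
  std::vector<int64_t> token_buf(1); // single-token decode input
  io_mgr.add_mask(token_buf.size(), /*zero_val=*/0.0f, /*mask_val=*/-1e9f);

  // decode() takes the functors by lvalue reference, so they are named here.
  std::function<int64_t(executorch::runtime::Method&)> sample =
      [](executorch::runtime::Method& m) -> int64_t {
        // Placeholder: argmax over the logits output of `m`.
        return 0;
      };
  std::function<bool(int64_t)> should_stop =
      [eos_token](int64_t tok) { return tok == eos_token; };

  // decode() calls prepare() internally each step, executes the method,
  // slides the caches forward by one token, and stops at EOS or when the
  // cache is full (tracked via the new StaticKVCache::size()).
  return io_mgr.decode(
      last_prefill_token,
      executorch::runtime::Span<int64_t>(token_buf.data(), token_buf.size()),
      method,
      sample,
      should_stop);
}

Packing the I/O indices into a config struct keeps the prepare()/decode() signatures stable as more inputs are added, at the cost of one struct the caller must fill correctly up front.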