diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h index 6a7559bc5bd..b077f414f02 100644 --- a/examples/models/llama/runner/static_attention_io_manager.h +++ b/examples/models/llama/runner/static_attention_io_manager.h @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -13,6 +14,7 @@ #include #include +#include namespace example { @@ -145,15 +147,16 @@ class StaticKVCache { void update( torch::executor::Method& method, const std::vector& output_indices, - size_t update_len) { + size_t update_len, + size_t update_pos = 0) { if (valid_len_ + update_len > cache_len_) { throw std::runtime_error("Cache capacity exceeded."); } if (style_ == StaticAttentionUpdateStyle::SLIDING_CACHE) { - update_sliding_cache(method, output_indices, update_len); + update_sliding_cache(method, output_indices, update_len, update_pos); } else { - update_smart_mask(method, output_indices, update_len); + update_smart_mask(method, output_indices, update_len, update_pos); } } @@ -185,15 +188,16 @@ class StaticKVCache { void update_sliding_cache( torch::executor::Method& method, const std::vector& output_indices, - size_t update_len) { + size_t update_len, + size_t update_pos) { ET_CHECK(n_caches_ == output_indices.size()); for (size_t i = 0; i < n_caches_; i++) { const auto& updateTensor = method.get_output(output_indices[i]).toTensor(); ET_CHECK(output_ptrs_[i] == updateTensor.const_data_ptr()); std::copy( - output_ptrs_[i], - output_ptrs_[i] + update_len * head_dim_, + output_ptrs_[i] + update_pos * head_dim_, + output_ptrs_[i] + (update_pos + update_len) * head_dim_, input_ptrs_[i] + cache_len_ * head_dim_); input_ptrs_[i] += update_len * head_dim_; } @@ -203,14 +207,15 @@ class StaticKVCache { void update_smart_mask( torch::executor::Method& method, const std::vector& output_indices, - size_t update_len) { + size_t update_len, + size_t update_pos) { for (size_t i = 0; i < n_caches_; i++) { const auto& updateTensor = method.get_output(output_indices[i]).toTensor(); ET_CHECK(output_ptrs_[i] == updateTensor.mutable_data_ptr()); std::copy( - output_ptrs_[i], - output_ptrs_[i] + update_len * head_dim_, + output_ptrs_[i] + update_pos * head_dim_, + output_ptrs_[i] + (update_pos + update_len) * head_dim_, input_ptrs_[i] + valid_len_ * head_dim_); } valid_len_ += update_len; @@ -322,6 +327,14 @@ class StaticAttentionMask { return data_; } + T zero_val() { + return zero_val_; + } + + T mask_val() { + return mask_val_; + } + private: size_t cache_len_; size_t input_len_; @@ -335,6 +348,59 @@ class StaticAttentionMask { T* data_; }; +template +class SuffixCache { + public: + SuffixCache(size_t n, size_t capacity) + : n_(n), capacity_(capacity), pos_(0), cache_(n_ * capacity_) {} + + void add(executorch::runtime::Span suffix) { + if (suffix.size() != n_ - 1) { + throw std::runtime_error("Wrong suffix length."); + } + for (size_t i = 0; i < capacity_; i++) { + auto* p = cache_.data() + (n_ - 1) * i; + if (std::equal(p, p + (n_ - 1), suffix.begin())) { + return; + } + } + auto* dst = cache_.data() + (n_ - 1) * pos_; + std::copy(suffix.begin(), suffix.end(), dst); + pos_ = (pos_ + 1) % capacity_; + } + + auto begin() { + return cache_.begin(); + } + auto end() { + return cache_.end(); + } + auto begin() const { + return cache_.begin(); + } + auto end() const { + return cache_.end(); + } + + static void seed_suffix_caches( + std::unordered_map>& suffix_caches, + executorch::runtime::Span toks, + size_t ngram_size, + size_t cache_size) { + for (size_t i = 0; i + ngram_size < toks.size(); i++) { + auto& cache = suffix_caches.try_emplace(toks[i], ngram_size, cache_size) + .first->second; + cache.add(executorch::runtime::Span(&toks[i + 1], ngram_size - 1)); + } + } + + private: + size_t n_; + size_t capacity_; + size_t pos_; + std::vector cache_; +}; + template < typename CacheT, typename MaskT, @@ -376,7 +442,14 @@ class StaticAttentionIOManager { config_.head_dim, config_.max_input_len, false, - config_.style) {} + config_.style) { + ET_LOG( + Info, + "Created StaticAttentionIOManager with" + " max input length = %zu cache length = %zu", + config_.max_input_len, + config_.cache_len); + } /** * Create a new StaticAttentionMask that will be managed by this object. @@ -406,19 +479,48 @@ class StaticAttentionIOManager { /** * Set I/O pointers for KV cache and RoPE freqencies. */ - void prepare(torch::executor::Method& method) { + void prepare( + torch::executor::Method& method, + std::optional> pos_offsets = + std::nullopt) { kCaches_.prepare( method, config_.k_cache_input_indices, config_.k_cache_output_indices); vCaches_.prepare( method, config_.v_cache_input_indices, config_.v_cache_output_indices); - set_input( - method, - config_.rope_freqs_cos_input_index, - config_.rope_freqs_cos + input_pos_ * config_.head_dim / 2); - set_input( - method, - config_.rope_freqs_sin_input_index, - config_.rope_freqs_sin + input_pos_ * config_.head_dim / 2); + + size_t rope_dim = config_.head_dim / 2; + if (pos_offsets) { + rope_freqs_cos_override_.clear(); + rope_freqs_sin_override_.clear(); + for (auto offset : *pos_offsets) { + auto pos = input_pos_ + offset; + std::copy( + config_.rope_freqs_cos + pos * rope_dim, + config_.rope_freqs_cos + (pos + 1) * rope_dim, + std::back_inserter(rope_freqs_cos_override_)); + std::copy( + config_.rope_freqs_sin + pos * rope_dim, + config_.rope_freqs_sin + (pos + 1) * rope_dim, + std::back_inserter(rope_freqs_sin_override_)); + } + set_input( + method, + config_.rope_freqs_cos_input_index, + rope_freqs_cos_override_.data()); + set_input( + method, + config_.rope_freqs_sin_input_index, + rope_freqs_sin_override_.data()); + } else { + set_input( + method, + config_.rope_freqs_cos_input_index, + config_.rope_freqs_cos + input_pos_ * rope_dim); + set_input( + method, + config_.rope_freqs_sin_input_index, + config_.rope_freqs_sin + input_pos_ * rope_dim); + } } /** @@ -429,10 +531,13 @@ class StaticAttentionIOManager { torch::executor::Method& method, const std::vector& k_cache_output_indices, const std::vector& v_cache_output_indices, - size_t update_len) { + size_t update_len, + size_t cache_update_pos = 0) { input_pos_ += update_len; - kCaches_.update(method, k_cache_output_indices, update_len); - vCaches_.update(method, v_cache_output_indices, update_len); + kCaches_.update( + method, k_cache_output_indices, update_len, cache_update_pos); + vCaches_.update( + method, v_cache_output_indices, update_len, cache_update_pos); for (auto& it : attentionMasks_) { it.second.unmask(update_len); } @@ -480,6 +585,162 @@ class StaticAttentionIOManager { return generated_tokens; } + template + std::vector lookahead_decode( + TokenT prev_tok, + executorch::runtime::Span input_buffer, + executorch::runtime::Method& method, + std::function(executorch::runtime::Method&)>& sample, + std::function& should_stop, + size_t ngram_size, + size_t window_size, + size_t n_verifications, + std::unordered_map> suffix_caches) { + set_input(method, 0, input_buffer.data()); + size_t input_len = input_buffer.size(); + + // Set up attention mask for current input length. + auto& mask = get_mask(input_buffer.size()); + set_lookahead_decoding_mask( + mask, input_len, ngram_size, window_size, n_verifications); + set_input(method, config_.attn_mask_input_index, mask.get()); + + // Position offsets relative to current position, for indexing RoPE + // frequence tensors. + auto pos_offsets = get_lookahead_pos_offsets( + input_len, ngram_size, window_size, n_verifications); + + ET_LOG( + Info, + "Starting lookahead decoding with" + " ngram_size = %zu" + " window_size = %zu" + " n_verifications = %zu", + ngram_size, + window_size, + n_verifications); + + // Decoding loop. + std::vector generated_tokens; + size_t verification_offset = + std::max(window_size * (ngram_size - 1), static_cast(1)); + size_t n_inference = 0; + std::fill(input_buffer.begin(), input_buffer.end(), prev_tok); + while (kCaches_.size() + 1 <= config_.cache_len) { + input_buffer[0] = prev_tok; + // Initialize verification branches. + if (auto it = suffix_caches.find(prev_tok); it != suffix_caches.end()) { + auto& cache = it->second; + std::copy( + cache.begin(), + cache.end(), + input_buffer.data() + verification_offset); + } + + // Setup input pointers and RoPE frequencies. + prepare( + method, + executorch::runtime::Span(pos_offsets.data(), pos_offsets.size())); + ET_CHECK(method.execute() == executorch::runtime::Error::Ok); + n_inference++; + // Update KV caches and mask for the 1st input position. If verification + // branches produced additional matches they'll be updated seprately + // because they are not contiguous in the KV cache. + update( + method, + config_.k_cache_output_indices, + config_.v_cache_output_indices, + 1); + + auto output_toks = sample(method); + + // Collect new n-grams from lookahead branches. + std::vector new_suffix; + for (size_t i = 0; i < window_size; i++) { + new_suffix.clear(); + for (size_t j = 1; j < ngram_size - 1; j++) { + new_suffix.emplace_back(input_buffer[i + window_size * j]); + } + new_suffix.emplace_back( + output_toks[i + window_size * (ngram_size - 2)]); + + auto& cache = + suffix_caches + .try_emplace(input_buffer[i], ngram_size, n_verifications) + .first->second; + cache.add(executorch::runtime::Span(new_suffix.data(), ngram_size - 1)); + } + + // Update lookahead branches. + for (size_t i = 0; i < ngram_size - 2; i++) { + for (size_t j = 0; j < window_size; j++) { + input_buffer[window_size * i + j] = + input_buffer[window_size * (i + 1) + j]; + } + } + for (size_t j = 0; j < window_size; j++) { + input_buffer[window_size * (ngram_size - 2) + j] = + output_toks[window_size * (ngram_size - 2) + j]; + } + + // Check verification results. + std::vector longest_match; + size_t matched_branch = 0; + for (size_t i = 0; i < n_verifications; i++) { + std::vector match; + match.emplace_back(output_toks[0]); + size_t branch_offset = verification_offset + (ngram_size - 1) * i; + for (size_t j = 0; j < ngram_size - 1 && + input_buffer[branch_offset + j] == match.back(); + j++) { + match.emplace_back(output_toks[branch_offset + j]); + if (should_stop(match.back())) { + break; + } + } + if (match.size() > longest_match.size()) { + longest_match = std::move(match); + matched_branch = i; + } + } + + bool generated_stop_tok = false; + for (auto tok : longest_match) { + generated_tokens.emplace_back(tok); + if (should_stop(tok)) { + generated_stop_tok = true; + break; + } + } + + // Update KV caches and mask for additional matches. + if (longest_match.size() > 1) { + size_t branch_offset = + verification_offset + (ngram_size - 1) * matched_branch; + update( + method, + config_.k_cache_output_indices, + config_.v_cache_output_indices, + std::min( + longest_match.size() - 1, config_.cache_len - kCaches_.size()), + branch_offset); + } + + if (generated_stop_tok) { + break; + } + prev_tok = generated_tokens.back(); + } + + ET_LOG( + Info, + "Generated %zu tokens with %zu inferences(s).", + generated_tokens.size(), + n_inference); + + return generated_tokens; + } + private: template void set_input(executorch::runtime::Method& method, size_t idx, T* data) { @@ -497,12 +758,111 @@ class StaticAttentionIOManager { ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok); } + void set_lookahead_decoding_mask( + StaticAttentionMask& mask, + size_t input_len, + size_t ngram_size, + size_t window_size, + size_t n_verifications) { + class SubMask { + public: + SubMask(MaskT* data, size_t stride) : data_(data), stride_(stride) {} + + MaskT& at(size_t i, size_t j = 0) { + return data_[i * stride_ + j]; + } + + private: + MaskT* data_; + size_t stride_; + }; + + size_t stride = config_.cache_len + input_len; + auto input_submask = SubMask(mask.get() + config_.cache_len, stride); + input_submask.at(0, 0) = mask.zero_val(); + + // Fill entire input mask first. + for (size_t i = 0; i < input_len; i++) { + auto* p = &input_submask.at(i); + std::fill(p, p + input_len, mask.mask_val()); + } + + auto set_causal_mask = [&](SubMask m, size_t size) { + for (size_t i = 0; i < size; i++) { + auto* p = &m.at(i); + std::fill(p, p + i + 1, mask.zero_val()); + } + }; + + auto set_diagonal_mask = [&](SubMask m, size_t size) { + for (size_t i = 0; i < size; i++) { + m.at(i, i) = mask.zero_val(); + } + }; + + // Set lookahead submasks. + for (size_t i = 0; i < ngram_size - 1; i++) { + set_causal_mask( + SubMask(&input_submask.at(window_size * i), stride), window_size); + for (size_t j = 1; j < i + 1; j++) { + set_diagonal_mask( + SubMask( + &input_submask.at(window_size * i, window_size * j), stride), + window_size); + } + } + + // Set verification submasks + size_t verification_offset = + std::max(window_size * (ngram_size - 1), static_cast(1)); + for (size_t i = 0; i < n_verifications; i++) { + size_t branch_offset = verification_offset + i * (ngram_size - 1); + set_causal_mask( + SubMask(&input_submask.at(branch_offset, branch_offset), stride), + ngram_size - 1); + } + for (size_t i = verification_offset; i < input_len; i++) { + input_submask.at(i, 0) = mask.zero_val(); + } + } + + std::vector get_lookahead_pos_offsets( + size_t input_len, + size_t ngram_size, + size_t window_size, + size_t n_verifications) { + std::vector offsets(input_len); + size_t idx = 0; + + // Lookahead branches: [i + 0, i + 1, ..., i + window_size - 1] + if (window_size > 0) { + for (size_t i = 0; i < ngram_size - 1; i++) { + for (size_t j = 0; j < window_size; j++) { + offsets[idx++] = i + j; + } + } + } else { + offsets[idx++] = 0; + } + + // Verification branches: [1, 2, ..., ngram_size - 1] + for (size_t i = 0; i < n_verifications; i++) { + for (size_t j = 1; j < ngram_size; j++) { + offsets[idx++] = j; + } + } + + return offsets; + } + StaticAttentionIOConfig config_; size_t input_pos_; StaticKVCache kCaches_; StaticKVCache vCaches_; std::unordered_map> attentionMasks_; + std::vector rope_freqs_cos_override_; + std::vector rope_freqs_sin_override_; }; } // namespace example