Skip to content

Commit 9d92ff8

Browse files
authored
StaticAttentionIOManager: prefill helper
Differential Revision: D79102521 Pull Request resolved: #12904
1 parent 02e50cc commit 9d92ff8

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,39 @@ class StaticAttentionIOManager {
602602
}
603603
}
604604

605+
/**
606+
* Prefill helper. Run multiple inferences as needed depending on the length
607+
* of the prompt and method's input length. Returns the position in the output
608+
* that corresponds to the end of the prompt during the last inference.
609+
*/
610+
template <typename TokenT>
611+
size_t prefill(
612+
executorch::runtime::Span<TokenT> tokens,
613+
executorch::runtime::Span<TokenT> input_buffer,
614+
executorch::runtime::Method& method) {
615+
size_t input_len = input_buffer.size();
616+
get_mask(input_buffer.size()).set_causal_mask();
617+
618+
size_t batch_len = 0;
619+
for (size_t i = 0; i < tokens.size(); i += input_len) {
620+
batch_len = std::min(input_len, tokens.size() - i);
621+
std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
622+
prepare(method);
623+
ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
624+
update(
625+
method,
626+
config_.k_cache_output_indices,
627+
config_.v_cache_output_indices,
628+
batch_len);
629+
}
630+
return batch_len - 1;
631+
}
632+
633+
/**
634+
* Decode helper. The `sample` argument is called after each inference and
635+
* should retrieve the logits from the `method` argument's output and return
636+
* the sampled token.
637+
*/
605638
template <typename TokenT>
606639
std::vector<TokenT> decode(
607640
TokenT prev_tok,
@@ -632,6 +665,11 @@ class StaticAttentionIOManager {
632665
return generated_tokens;
633666
}
634667

668+
/**
669+
* Lookahead decode helper. The `sample` argument is called after each
670+
* inference and should retrieve the logits from the `method` argument's
671+
* output and return the sampled token for all output positions.
672+
*/
635673
template <typename TokenT>
636674
std::vector<TokenT> lookahead_decode(
637675
TokenT prev_tok,

0 commit comments

Comments
 (0)