@@ -602,6 +602,39 @@ class StaticAttentionIOManager {
     }
   }

+  /**
+   * Prefill helper. Runs as many inferences as needed given the prompt
+   * length and the method's input length. Returns the position in the output
+   * that corresponds to the end of the prompt during the last inference.
+   */
+  template <typename TokenT>
+  size_t prefill(
+      executorch::runtime::Span<TokenT> tokens,
+      executorch::runtime::Span<TokenT> input_buffer,
+      executorch::runtime::Method& method) {
+    size_t input_len = input_buffer.size();
+    get_mask(input_buffer.size()).set_causal_mask();
+
+    size_t batch_len = 0;
+    for (size_t i = 0; i < tokens.size(); i += input_len) {
+      batch_len = std::min(input_len, tokens.size() - i);
+      std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
+      prepare(method);
+      ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
+      update(
+          method,
+          config_.k_cache_output_indices,
+          config_.v_cache_output_indices,
+          batch_len);
+    }
+    return batch_len - 1;
+  }
+
+  /**
+   * Decode helper. The `sample` argument is called after each inference and
+   * should retrieve the logits from the `method` argument's output and return
+   * the sampled token.
+   */
   template <typename TokenT>
   std::vector<TokenT> decode(
       TokenT prev_tok,
@@ -632,6 +665,11 @@ class StaticAttentionIOManager {
     return generated_tokens;
   }

+  /**
+   * Lookahead decode helper. The `sample` argument is called after each
+   * inference and should retrieve the logits from the `method` argument's
+   * output and return the sampled token for all output positions.
+   */
   template <typename TokenT>
   std::vector<TokenT> lookahead_decode(
       TokenT prev_tok,
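For orientation, here is a minimal usage sketch of the new `prefill` helper (not part of the diff). The templated driver, the `io_mgr` instance, and the `sample_at` callback are assumptions introduced for illustration; only `prefill`'s signature and its return-value contract come from the code above.

#include <cstdint>
#include <vector>

#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/method.h>

// Hypothetical driver, templated over the manager type so the sketch does not
// commit to StaticAttentionIOManager's template parameters. `sample_at` is an
// assumed user-provided sampler that reads the logits from `method`'s output
// at the given position and returns a token id.
template <typename IOManagerT, typename SamplerT>
int32_t prefill_prompt(
    IOManagerT& io_mgr,
    executorch::runtime::Method& method,
    std::vector<int32_t>& prompt_tokens,
    size_t input_len,
    SamplerT&& sample_at) {
  // Fixed-size buffer matching the method's static input length.
  std::vector<int32_t> input_buffer(input_len);

  // Feed the prompt through the model in input_len-sized chunks; prefill()
  // returns the output position of the prompt's last token in the final
  // chunk.
  size_t last_pos = io_mgr.prefill(
      executorch::runtime::Span<int32_t>(
          prompt_tokens.data(), prompt_tokens.size()),
      executorch::runtime::Span<int32_t>(
          input_buffer.data(), input_buffer.size()),
      method);

  // The token sampled at `last_pos` is the first generated token; it becomes
  // `prev_tok` for decode() or lookahead_decode().
  return sample_at(method, last_pos);
}

Because `prefill` chunks the prompt through the same fixed-length method and sets a causal mask for each batch, prompts longer than the method's input length work without a separate variable-length prefill graph.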