@@ -602,6 +602,39 @@ class StaticAttentionIOManager {
     }
   }
 
+  /**
+   * Prefill helper. Runs multiple inferences as needed depending on the length
+   * of the prompt and the method's input length. Returns the position in the
+   * output that corresponds to the end of the prompt during the last inference.
+   */
+  template <typename TokenT>
+  size_t prefill(
+      executorch::runtime::Span<TokenT> tokens,
+      executorch::runtime::Span<TokenT> input_buffer,
+      executorch::runtime::Method& method) {
+    size_t input_len = input_buffer.size();
+    get_mask(input_buffer.size()).set_causal_mask();
+
+    size_t batch_len = 0;
+    for (size_t i = 0; i < tokens.size(); i += input_len) {
+      batch_len = std::min(input_len, tokens.size() - i);
+      std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
+      prepare(method);
+      ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
+      update(
+          method,
+          config_.k_cache_output_indices,
+          config_.v_cache_output_indices,
+          batch_len);
+    }
+    return batch_len - 1;
+  }
+
+  /**
+   * Decode helper. The `sample` argument is called after each inference and
+   * should retrieve the logits from the `method` argument's output and return
+   * the sampled token.
+   */
   template <typename TokenT>
   std::vector<TokenT> decode(
       TokenT prev_tok,
@@ -632,6 +665,11 @@ class StaticAttentionIOManager {
     return generated_tokens;
   }
 
+  /**
+   * Lookahead decode helper. The `sample` argument is called after each
+   * inference and should retrieve the logits from the `method` argument's
+   * output and return the sampled token for all output positions.
+   */
   template <typename TokenT>
   std::vector<TokenT> lookahead_decode(
       TokenT prev_tok,
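
For context, here is a minimal caller sketch (not part of the diff above) of how `prefill` might be invoked before handing generation off to `decode`. Only `prefill`'s signature comes from the diff; the wrapper function, the `int64_t` token type, the `seq_len` parameter, and the header paths are assumptions for illustration.

```cpp
#include <cstdint>
#include <vector>

#include <executorch/runtime/core/span.h>        // assumed header locations
#include <executorch/runtime/executor/method.h>

// Hypothetical wrapper: IOManagerT stands in for whatever concrete
// StaticAttentionIOManager instantiation the runner uses.
template <typename IOManagerT>
size_t prefill_prompt(
    IOManagerT& io_manager,
    executorch::runtime::Method& method,
    std::vector<int64_t>& prompt_tokens,
    size_t seq_len) {  // seq_len: the method's fixed input length (assumed)
  // Scratch buffer sized to the method's input length; prefill() chunks the
  // prompt into batches of this size internally.
  std::vector<int64_t> input_buffer(seq_len);

  // The return value is the output position of the last prompt token in the
  // final inference; sampling from that position yields the first generated
  // token, which is then passed to decode() as `prev_tok`.
  return io_manager.prefill(
      executorch::runtime::Span<int64_t>(
          prompt_tokens.data(), prompt_tokens.size()),
      executorch::runtime::Span<int64_t>(
          input_buffer.data(), input_buffer.size()),
      method);
}
```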
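Similarly, a sketch of the `sample` callback contract described in the `decode` doc comment. `decode`'s full parameter list is truncated in this excerpt, so only the callback shape is shown; the output slot index `0`, the float logits layout, and the greedy (argmax) policy are all assumptions, not confirmed by the diff.

```cpp
#include <cstdint>

#include <executorch/runtime/executor/method.h>  // assumed header location

// Hypothetical greedy sampler: after each inference, read the logits from
// the Method's output and return the highest-scoring token id. For
// simplicity this treats the whole output as one vocabulary distribution,
// which only holds for single-position logits.
auto sample = [](executorch::runtime::Method& method) -> int64_t {
  const auto& logits = method.get_output(0).toTensor();  // assumed slot 0
  const float* data = logits.const_data_ptr<float>();
  int64_t best = 0;
  for (int64_t i = 1; i < logits.numel(); ++i) {
    if (data[i] > data[best]) {
      best = i;
    }
  }
  return best;  // greedy (argmax) sampling
};
```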