 // LLM.
 
 #include <executorch/extension/llm/runner/text_prefiller.h>
+#include <algorithm>
 
 namespace executorch {
 namespace extension {
@@ -18,10 +19,12 @@ namespace llm {
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
     bool use_kv_cache,
-    bool enable_parallel_prefill)
+    bool enable_parallel_prefill,
+    int64_t max_seq_len)
     : text_decoder_runner_(text_decoder_runner),
       use_kv_cache_(use_kv_cache),
-      enable_parallel_prefill_(enable_parallel_prefill) {}
+      enable_parallel_prefill_(enable_parallel_prefill),
+      max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {} // -1 because for some reason tracing results in this upper bound
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
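For reference, the clamp in the new initializer above behaves as sketched below. This is a minimal illustrative snippet, not part of the commit; the helper name effective_max_seq_len and the sample values are hypothetical.

// Illustrative only: mirrors the ternary in the constructor's initializer
// list. The -1 matches the tracing upper bound noted in the code comment;
// 127 is the fallback when no positive max_seq_len is supplied.
#include <cstdint>
#include <iostream>

int64_t effective_max_seq_len(int64_t max_seq_len) { // hypothetical helper
  return max_seq_len > 0 ? max_seq_len - 1 : 127;
}

int main() {
  std::cout << effective_max_seq_len(2048) << "\n"; // prints 2047
  std::cout << effective_max_seq_len(0) << "\n";    // prints 127
}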
@@ -30,6 +33,43 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
   }
+
+  // Check if we need to chunk the prompt tokens
+  int32_t num_prompt_tokens = prompt_tokens.size();
+
+  // If prompt tokens exceed max_seq_len_, we need to chunk them
+  if (num_prompt_tokens > max_seq_len_) {
+    uint64_t cur_token = 0;
+    int num_tokens_to_process = 0;
+
+    while (num_tokens_to_process < num_prompt_tokens) {
+      auto num_tokens_to_prefill_with =
+          std::min<int>(num_prompt_tokens - num_tokens_to_process, max_seq_len_);
+
+      std::vector<uint64_t> prompt_tokens_to_process(num_tokens_to_prefill_with);
+      std::copy(
+          prompt_tokens.begin() + num_tokens_to_process,
+          prompt_tokens.begin() + num_tokens_to_process +
+              num_tokens_to_prefill_with,
+          prompt_tokens_to_process.begin());
+
+      // Process this chunk
+      auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
+      ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
+      cur_token = chunk_result.get();
+
+      num_tokens_to_process += num_tokens_to_prefill_with;
+    }
+
+    return cur_token;
+  } else {
+    // If prompt tokens don't exceed max_seq_len_, process them directly
+    return prefillChunk(prompt_tokens, start_pos);
+  }
+}
+
+::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
+    std::vector<uint64_t>& prompt_tokens,
+    int64_t& start_pos) {
   // enable_parallel_prefill_ may be set even when not using kv cache
   // When kv cache is not used, start pos is ignored
   int32_t num_prompt_tokens = prompt_tokens.size();
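To make the slicing arithmetic above concrete, here is a minimal sketch of how prefill() splits a prompt into chunks no larger than max_seq_len_. The helper chunk_sizes and the 300-token example are hypothetical, written only to illustrate the loop's bookkeeping.

// Illustrative only: reproduces the chunk-size computation from the
// while loop in prefill(); each entry is the size of one prefillChunk() call.
#include <algorithm>
#include <iostream>
#include <vector>

std::vector<int> chunk_sizes(int num_prompt_tokens, int max_seq_len) { // hypothetical helper
  std::vector<int> sizes;
  int processed = 0;
  while (processed < num_prompt_tokens) {
    int n = std::min(num_prompt_tokens - processed, max_seq_len);
    sizes.push_back(n);
    processed += n;
  }
  return sizes;
}

int main() {
  // A 300-token prompt with max_seq_len_ == 127 splits into 127, 127, 46.
  for (int n : chunk_sizes(300, 127)) {
    std::cout << n << " ";
  }
  std::cout << "\n";
}

Since start_pos is passed to prefillChunk() by reference, each chunk presumably resumes from the position the previous chunk advanced it to, so the KV cache sees the prompt as one contiguous sequence.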