1111
1212#include < executorch/examples/models/llama/runner/runner.h>
1313
14+ #include < algorithm>
1415#include < ctime>
1516
1617#include < executorch/extension/llm/runner/util.h>
@@ -221,11 +222,11 @@ Error Runner::generate(
221222
222223 ET_CHECK_MSG (num_prompt_tokens >= 1 , " Expected at least 1 prompt token" );
223224 ET_CHECK_MSG (
224- num_prompt_tokens < metadata_.at (kMaxSeqLen ),
225+ num_prompt_tokens < metadata_.at (kMaxContextLen ),
225226 " num_prompt_tokens %d >= max_seq_len_ %" PRId64
226227 " , Max seq length exceeded - please increase max seq len value in .../llama2/model.py" ,
227228 num_prompt_tokens,
228- metadata_.at (kMaxSeqLen ));
229+ metadata_.at (kMaxContextLen ));
229230 ET_CHECK_MSG (
230231 num_prompt_tokens < seq_len,
231232 " num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()" ,
@@ -241,11 +242,26 @@ Error Runner::generate(
241242 wrapped_callback (prompt);
242243 }
243244 int64_t pos = 0 ;
244- auto prefill_res = text_prefiller_->prefill (prompt_tokens, pos);
245+ uint64_t cur_token;
246+ int max_seq_len = metadata_.at (kMaxSeqLen ) -
247+ 1 ; // -1 because for some reason tracing results in this upperbound
248+ int num_tokens_to_process = 0 ;
249+ while (num_tokens_to_process < num_prompt_tokens) {
250+ auto num_tokens_to_prefill_with =
251+ std::min (num_prompt_tokens - num_tokens_to_process, max_seq_len);
252+ std::vector<uint64_t > prompt_tokens_to_process (num_tokens_to_prefill_with);
253+ std::copy (
254+ prompt_tokens.begin () + num_tokens_to_process,
255+ prompt_tokens.begin () + num_tokens_to_process + num_tokens_to_prefill_with,
256+ prompt_tokens_to_process.begin ());
257+ auto prefill_res =
258+ text_prefiller_->prefill (prompt_tokens_to_process, pos);
259+ ET_CHECK_OK_OR_RETURN_ERROR (prefill_res.error ());
260+ cur_token = prefill_res.get ();
261+ num_tokens_to_process += num_tokens_to_prefill_with;
262+ }
245263 stats_.first_token_ms = llm::time_in_ms ();
246264 stats_.prompt_eval_end_ms = llm::time_in_ms ();
247- ET_CHECK_OK_OR_RETURN_ERROR (prefill_res.error ());
248- uint64_t cur_token = prefill_res.get ();
249265
250266 // print the first token from prefill. No prev_token so use cur_token for it.
251267 wrapped_callback (
0 commit comments