1111
1212#include  < executorch/examples/models/llama/runner/runner.h> 
1313
14+ #include  < algorithm> 
1415#include  < ctime> 
1516
1617#include  < executorch/extension/llm/runner/util.h> 
@@ -221,11 +222,11 @@ Error Runner::generate(
221222
222223  ET_CHECK_MSG (num_prompt_tokens >= 1 , " Expected at least 1 prompt token" 
223224  ET_CHECK_MSG (
224-       num_prompt_tokens < metadata_.at (kMaxSeqLen ),
225+       num_prompt_tokens < metadata_.at (kMaxContextLen ),
225226      " num_prompt_tokens %d >= max_seq_len_ %" 
226227      " , Max seq length exceeded - please increase max seq len value in .../llama2/model.py" 
227228      num_prompt_tokens,
228-       metadata_.at (kMaxSeqLen ));
229+       metadata_.at (kMaxContextLen ));
229230  ET_CHECK_MSG (
230231      num_prompt_tokens < seq_len,
231232      " num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()" 
@@ -241,11 +242,26 @@ Error Runner::generate(
241242    wrapped_callback (prompt);
242243  }
243244  int64_t  pos = 0 ;
244-   auto  prefill_res = text_prefiller_->prefill (prompt_tokens, pos);
245+   uint64_t  cur_token;
246+   int  max_seq_len = metadata_.at (kMaxSeqLen ) -
247+       1 ; //  -1 because for some reason tracing results in this upperbound
248+   int  num_tokens_to_process = 0 ;
249+   while  (num_tokens_to_process < num_prompt_tokens) {
250+     auto  num_tokens_to_prefill_with =
251+         std::min (num_prompt_tokens - num_tokens_to_process, max_seq_len);
252+     std::vector<uint64_t > prompt_tokens_to_process (num_tokens_to_prefill_with);
253+     std::copy (
254+         prompt_tokens.begin () + num_tokens_to_process,
255+         prompt_tokens.begin () + num_tokens_to_process +
256+             num_tokens_to_prefill_with,
257+         prompt_tokens_to_process.begin ());
258+     auto  prefill_res = text_prefiller_->prefill (prompt_tokens_to_process, pos);
259+     ET_CHECK_OK_OR_RETURN_ERROR (prefill_res.error ());
260+     cur_token = prefill_res.get ();
261+     num_tokens_to_process += num_tokens_to_prefill_with;
262+   }
245263  stats_.first_token_ms  = llm::time_in_ms ();
246264  stats_.prompt_eval_end_ms  = llm::time_in_ms ();
247-   ET_CHECK_OK_OR_RETURN_ERROR (prefill_res.error ());
248-   uint64_t  cur_token = prefill_res.get ();
249265
250266  //  print the first token from prefill. No prev_token so use cur_token for it.
251267  wrapped_callback (
0 commit comments