
Commit 51556f0

[ExecuTorch][Llama] Change runner to enable chunked prefill
Pull Request resolved: #9785

This diff adds code to chunk a prompt longer than max_seq_len so that prefill of a larger context is possible.

ghstack-source-id: 275281050
Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/)
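The gist of the change, as a minimal standalone sketch (the chunk size, prompt length, and variable names below are illustrative demo values, not the exact runner code): the prompt is walked in consecutive windows of at most max_seq_len tokens, each window is prefilled in turn, and the KV-cache position advances past the tokens already consumed.

// Illustrative sketch only: splitting a prompt longer than max_seq_len into
// consecutive chunks, mirroring the loop added to TextPrefiller below.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int max_seq_len = 8;                   // assumed model limit for the demo
  std::vector<uint64_t> prompt_tokens(21, 0);  // 21 dummy prompt tokens

  const int num_prompt_tokens = static_cast<int>(prompt_tokens.size());
  int64_t start_pos = 0;
  int num_tokens_to_process = 0;

  while (num_tokens_to_process < num_prompt_tokens) {
    const int chunk_len =
        std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len);
    // A real runner would prefill this [begin, end) window here; the
    // KV-cache position (start_pos) advances by chunk_len per chunk.
    std::printf(
        "prefill tokens [%d, %d) at start_pos %lld\n",
        num_tokens_to_process,
        num_tokens_to_process + chunk_len,
        static_cast<long long>(start_pos));
    start_pos += chunk_len;
    num_tokens_to_process += chunk_len;
  }
  return 0;
}

With these demo values, the 21-token prompt is prefilled in three chunks of 8, 8, and 5 tokens.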
1 parent 2aa7748 commit 51556f0

4 files changed: +64 −9 lines

examples/models/llama/runner/runner.cpp

Lines changed: 7 additions & 5 deletions
@@ -11,6 +11,7 @@

 #include <executorch/examples/models/llama/runner/runner.h>

+#include <algorithm>
 #include <ctime>

 #include <executorch/extension/llm/runner/util.h>
@@ -140,7 +141,8 @@ Error Runner::load() {
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape));
+      metadata_.at(kEnableDynamicShape),
+      metadata_.at(kMaxSeqLen));

   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
       tokenizer_.get(),
@@ -221,11 +223,11 @@ Error Runner::generate(

   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
-      num_prompt_tokens < metadata_.at(kMaxSeqLen),
+      num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
       ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
       num_prompt_tokens,
-      metadata_.at(kMaxSeqLen));
+      metadata_.at(kMaxContextLen));
   ET_CHECK_MSG(
       num_prompt_tokens < seq_len,
       "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
@@ -242,10 +244,10 @@ Error Runner::generate(
   }
   int64_t pos = 0;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
+  stats_.first_token_ms = llm::time_in_ms();
+  stats_.prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(

examples/models/llava/runner/llava_runner.cpp

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ Error LlavaRunner::load() {
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       /*use_kv_cache=*/true,
-      /*enable_parallel_prefill=*/true);
+      /*enable_parallel_prefill=*/true,
+      /*max_seq_len=*/128);

   // Load the image prefiller
   image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());

extension/llm/runner/text_prefiller.cpp

Lines changed: 42 additions & 2 deletions
@@ -10,6 +10,7 @@
 // LLM.

 #include <executorch/extension/llm/runner/text_prefiller.h>
+#include <algorithm>

 namespace executorch {
 namespace extension {
@@ -18,10 +19,12 @@ namespace llm {
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
     bool use_kv_cache,
-    bool enable_parallel_prefill)
+    bool enable_parallel_prefill,
+    int64_t max_seq_len)
     : text_decoder_runner_(text_decoder_runner),
       use_kv_cache_(use_kv_cache),
-      enable_parallel_prefill_(enable_parallel_prefill) {}
+      enable_parallel_prefill_(enable_parallel_prefill),
+      max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {} // -1 because for some reason tracing results in this upperbound

 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
@@ -30,6 +33,43 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
   }
+
+  // Check if we need to chunk the prompt tokens
+  int32_t num_prompt_tokens = prompt_tokens.size();
+
+  // If prompt tokens exceed max_seq_len_, we need to chunk them
+  if (num_prompt_tokens > max_seq_len_) {
+    uint64_t cur_token = 0;
+    int num_tokens_to_process = 0;
+
+    while (num_tokens_to_process < num_prompt_tokens) {
+      auto num_tokens_to_prefill_with =
+          std::min<int>(num_prompt_tokens - num_tokens_to_process, max_seq_len_);
+
+      std::vector<uint64_t> prompt_tokens_to_process(num_tokens_to_prefill_with);
+      std::copy(
+          prompt_tokens.begin() + num_tokens_to_process,
+          prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with,
+          prompt_tokens_to_process.begin());
+
+      // Process this chunk
+      auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
+      ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
+      cur_token = chunk_result.get();
+
+      num_tokens_to_process += num_tokens_to_prefill_with;
+    }
+
+    return cur_token;
+  } else {
+    // If prompt tokens don't exceed max_seq_len_, process them directly
+    return prefillChunk(prompt_tokens, start_pos);
+  }
+}
+
+::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
+    std::vector<uint64_t>& prompt_tokens,
+    int64_t& start_pos) {
   // enable_parallel_prefill_ maybe set even when not using kv cache
   // When kv cache is not used, start pos is ignored
   int32_t num_prompt_tokens = prompt_tokens.size();
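Note that prefillChunk takes start_pos by reference and the chunking loop above re-passes the same start_pos for each successive chunk without advancing it itself; so, as in the pre-existing prefill path this body was carved out of, prefillChunk is expected to advance start_pos past the tokens it consumes.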

extension/llm/runner/text_prefiller.h

Lines changed: 13 additions & 1 deletion
@@ -22,7 +22,8 @@ class ET_EXPERIMENTAL TextPrefiller {
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
       bool use_kv_cache_,
-      bool enable_parallel_prefill);
+      bool enable_parallel_prefill,
+      int64_t max_seq_len = 128);
   /**
    * Prefill an LLM Module with the given text input.
    * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
@@ -35,10 +36,21 @@ class ET_EXPERIMENTAL TextPrefiller {
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);

+  /**
+   * Helper method to prefill a chunk of tokens.
+   * @param prompt_tokens The chunk of text prompt tokens to process.
+   * @param start_pos The starting position in KV cache of the input in the LLM Module.
+   * @return The next token of the LLM Module after prefilling this chunk.
+   */
+  ::executorch::runtime::Result<uint64_t> prefillChunk(
+      std::vector<uint64_t>& prompt_tokens,
+      int64_t& start_pos);
+
  private:
   TextDecoderRunner* text_decoder_runner_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
+  int64_t max_seq_len_;
 };

 } // namespace llm
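For callers, nothing changes except the optional max_seq_len constructor argument; chunking happens inside prefill(). A hedged sketch of a call site follows (the decoder_runner pointer, function name, and prompt contents are placeholders; it mirrors the wiring shown in runner.cpp above rather than being a verbatim excerpt):

#include <executorch/extension/llm/runner/text_prefiller.h>

#include <cstdint>
#include <vector>

namespace llm = ::executorch::extension::llm;

// `decoder_runner` is assumed to point at an already-loaded TextDecoderRunner.
::executorch::runtime::Result<uint64_t> prefill_prompt(
    llm::TextDecoderRunner* decoder_runner,
    std::vector<uint64_t>& prompt_tokens) {
  llm::TextPrefiller prefiller(
      decoder_runner,
      /*use_kv_cache=*/true,
      /*enable_parallel_prefill=*/true,
      /*max_seq_len=*/128); // default from text_prefiller.h; pass the model's real limit

  int64_t start_pos = 0;
  // prefill() now chunks internally when prompt_tokens.size() > max_seq_len.
  auto first_token = prefiller.prefill(prompt_tokens, start_pos);
  ET_CHECK_OK_OR_RETURN_ERROR(first_token.error());
  return first_token.get();
}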
