Commit b5dfef9

Update on "[ExecuTorch][Llama] Change runner to enable chunked prefill"
This diff adds code to chunk prompts longer than max_seq_len, enabling prefill of a larger context.

Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/)

[ghstack-poisoned]
1 parent: d6c9e45, commit: b5dfef9

File tree: 3 files changed (+60, -22 lines)

examples/models/llama/runner/runner.cpp (5 additions & 19 deletions)

@@ -141,7 +141,8 @@ Error Runner::load() {
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape));
+      metadata_.at(kEnableDynamicShape),
+      metadata_.at(kMaxSeqLen));
 
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
       tokenizer_.get(),
@@ -242,24 +243,9 @@ Error Runner::generate(
     wrapped_callback(prompt);
   }
   int64_t pos = 0;
-  uint64_t cur_token;
-  int max_seq_len = metadata_.at(kMaxSeqLen) -
-      1; // -1 because for some reason tracing results in this upperbound
-  int num_tokens_to_process = 0;
-  while (num_tokens_to_process < num_prompt_tokens) {
-    auto num_tokens_to_prefill_with =
-        std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len);
-    std::vector<uint64_t> prompt_tokens_to_process(num_tokens_to_prefill_with);
-    std::copy(
-        prompt_tokens.begin() + num_tokens_to_process,
-        prompt_tokens.begin() + num_tokens_to_process +
-            num_tokens_to_prefill_with,
-        prompt_tokens_to_process.begin());
-    auto prefill_res = text_prefiller_->prefill(prompt_tokens_to_process, pos);
-    ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
-    cur_token = prefill_res.get();
-    num_tokens_to_process += num_tokens_to_prefill_with;
-  }
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
+  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+  uint64_t cur_token = prefill_res.get();
   stats_.first_token_ms = llm::time_in_ms();
   stats_.prompt_eval_end_ms = llm::time_in_ms();
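
Note: the runner previously performed this chunking itself; after this change it hands the full prompt to a single text_prefiller_->prefill() call and the loop lives in TextPrefiller. The snippet below is a standalone, illustrative sketch (not ExecuTorch code; all numbers are made up) of the chunk-size arithmetic the removed loop performed: a 300-token prompt with an effective cap of 127 splits into chunks of 127, 127, and 46.

#include <algorithm>
#include <iostream>

int main() {
  // Hypothetical sizes for illustration only.
  const int num_prompt_tokens = 300;
  const int max_seq_len = 128 - 1;  // the removed runner code subtracted 1 from kMaxSeqLen

  int num_tokens_to_process = 0;
  while (num_tokens_to_process < num_prompt_tokens) {
    int chunk = std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len);
    std::cout << "prefill chunk of " << chunk << " tokens starting at offset "
              << num_tokens_to_process << '\n';
    num_tokens_to_process += chunk;
  }
  // Prints chunks of 127, 127, and 46 tokens.
  return 0;
}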

extension/llm/runner/text_prefiller.cpp (42 additions & 2 deletions)

@@ -10,6 +10,7 @@
 // LLM.
 
 #include <executorch/extension/llm/runner/text_prefiller.h>
+#include <algorithm>
 
 namespace executorch {
 namespace extension {
@@ -18,10 +19,12 @@ namespace llm {
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
     bool use_kv_cache,
-    bool enable_parallel_prefill)
+    bool enable_parallel_prefill,
+    int64_t max_seq_len)
     : text_decoder_runner_(text_decoder_runner),
       use_kv_cache_(use_kv_cache),
-      enable_parallel_prefill_(enable_parallel_prefill) {}
+      enable_parallel_prefill_(enable_parallel_prefill),
+      max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {} // -1 because for some reason tracing results in this upperbound
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
@@ -30,6 +33,43 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
   }
+
+  // Check if we need to chunk the prompt tokens
+  int32_t num_prompt_tokens = prompt_tokens.size();
+
+  // If prompt tokens exceed max_seq_len_, we need to chunk them
+  if (num_prompt_tokens > max_seq_len_) {
+    uint64_t cur_token = 0;
+    int num_tokens_to_process = 0;
+
+    while (num_tokens_to_process < num_prompt_tokens) {
+      auto num_tokens_to_prefill_with =
+          std::min<int>(num_prompt_tokens - num_tokens_to_process, max_seq_len_);
+
+      std::vector<uint64_t> prompt_tokens_to_process(num_tokens_to_prefill_with);
+      std::copy(
+          prompt_tokens.begin() + num_tokens_to_process,
+          prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with,
+          prompt_tokens_to_process.begin());
+
+      // Process this chunk
+      auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
+      ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
+      cur_token = chunk_result.get();
+
+      num_tokens_to_process += num_tokens_to_prefill_with;
+    }
+
+    return cur_token;
+  } else {
+    // If prompt tokens don't exceed max_seq_len_, process them directly
+    return prefillChunk(prompt_tokens, start_pos);
+  }
+}
+
+::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
+    std::vector<uint64_t>& prompt_tokens,
+    int64_t& start_pos) {
   // enable_parallel_prefill_ maybe set even when not using kv cache
   // When kv cache is not used, start pos is ignored
   int32_t num_prompt_tokens = prompt_tokens.size();
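
Note: TextPrefiller::prefill() now owns the chunking. Each iteration copies a half-open window of prompt_tokens into a scratch vector and forwards it to prefillChunk(). The following standalone sketch (token values and window bounds are made up, not taken from the diff) shows just that windowed std::copy pattern.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Illustrative prompt of 10 token ids; real ids would come from the tokenizer.
  std::vector<uint64_t> prompt_tokens = {5, 9, 2, 7, 1, 3, 8, 4, 6, 0};
  int num_tokens_to_process = 4;       // tokens already prefilled
  int num_tokens_to_prefill_with = 3;  // size of the next chunk

  // Copy the half-open window [4, 7) into a scratch buffer, as the loop above does.
  std::vector<uint64_t> chunk(num_tokens_to_prefill_with);
  std::copy(
      prompt_tokens.begin() + num_tokens_to_process,
      prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with,
      chunk.begin());

  for (uint64_t tok : chunk) {
    std::cout << tok << ' ';  // prints: 1 3 8
  }
  std::cout << '\n';
  return 0;
}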

extension/llm/runner/text_prefiller.h (13 additions & 1 deletion)

@@ -22,7 +22,8 @@ class ET_EXPERIMENTAL TextPrefiller {
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
       bool use_kv_cache_,
-      bool enable_parallel_prefill);
+      bool enable_parallel_prefill,
+      int64_t max_seq_len = 128);
   /**
    * Prefill an LLM Module with the given text input.
    * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
@@ -35,10 +36,21 @@ class ET_EXPERIMENTAL TextPrefiller {
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);
 
+  /**
+   * Helper method to prefill a chunk of tokens.
+   * @param prompt_tokens The chunk of text prompt tokens to process.
+   * @param start_pos The starting position in KV cache of the input in the LLM Module.
+   * @return The next token of the LLM Module after prefilling this chunk.
+   */
+  ::executorch::runtime::Result<uint64_t> prefillChunk(
+      std::vector<uint64_t>& prompt_tokens,
+      int64_t& start_pos);
+
  private:
   TextDecoderRunner* text_decoder_runner_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
+  int64_t max_seq_len_;
 };
 
 } // namespace llm
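
Note: the header gives max_seq_len a default of 128, and the constructor stores max_seq_len - 1 when the value is positive, falling back to 127 otherwise. This standalone snippet (the function name effective_cap is made up for illustration) mirrors that initializer so the resulting caps are easy to check.

#include <cstdint>
#include <iostream>

// Mirrors the constructor's max_seq_len_ initializer; illustrative only.
int64_t effective_cap(int64_t max_seq_len) {
  return max_seq_len > 0 ? max_seq_len - 1 : 127;
}

int main() {
  std::cout << effective_cap(128) << '\n';   // 127 (the default argument)
  std::cout << effective_cap(2048) << '\n';  // 2047
  std::cout << effective_cap(0) << '\n';     // 127 (fallback)
  return 0;
}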
