Commit 34b6a2d

text_llm_runner manage start_pos internally (#14028)
Currently, the multimodal runner can manage the KV cache position internally and avoid exposing start_pos to users. text_llm_runner should do the same.
1 parent 9ce07da commit 34b6a2d

4 files changed: +48 / -85 lines

extension/llm/runner/irunner.h

Lines changed: 13 additions & 1 deletion
@@ -130,7 +130,9 @@ class ET_EXPERIMENTAL IRunner {
    * given position in KV cache.
    *
    * @param prompt The input prompt to generate from
-   * @param start_pos The starting position in KV cache of the input
+   * @param start_pos The starting position in KV cache of the input. Note:
+   * Depending on the actual implementation, a runner may manage the position
+   * internally, and this may not be respected.
    * @param config Generation configuration parameters
    * @param token_callback Callback function called for each generated token
    * @param stats_callback Callback function for generation statistics
@@ -146,6 +148,16 @@ class ET_EXPERIMENTAL IRunner {
    * Stop the generation process.
    */
  virtual void stop() = 0;
+  /**
+   * Force remove prefilled tokens and reset KV cache start position
+   *
+   * For some existing runners, overriding this method is not needed because
+   * start_pos is passed as an argument to generate_from_pos.
+   *
+   * This method removes the prefilled tokens from the KV cache and resets the
+   * start position to 0.
+   */
+  virtual void reset() {};
 };
 
 } // namespace llm
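
To make the new interface contract concrete, here is a minimal caller-side sketch. It uses only the members visible in this diff (generate_from_pos and the new reset); the include path, namespace alias, helper function, and prompts are illustrative assumptions, not part of the commit.

#include <functional>
#include <iostream>
#include <string>

#include <executorch/extension/llm/runner/irunner.h> // assumed include path

namespace llm = ::executorch::extension::llm; // assumed enclosing namespace

// Hypothetical helper: run two unrelated prompts against any IRunner.
void run_two_prompts(llm::IRunner& runner, const llm::GenerationConfig& config) {
  auto print_token = [](const std::string& piece) { std::cout << piece; };
  auto ignore_stats = [](const llm::Stats&) {};

  // start_pos = 0 for the first prompt; per the note above, implementations
  // that manage the position internally are free to ignore this argument.
  auto err = runner.generate_from_pos(
      "first prompt", /*start_pos=*/0, config, print_token, ignore_stats);
  if (err != ::executorch::runtime::Error::Ok) {
    return;
  }

  // reset() removes the prefilled tokens and rewinds the KV cache position
  // to 0, so the second prompt starts from a clean context.
  runner.reset();
  err = runner.generate_from_pos(
      "second prompt", /*start_pos=*/0, config, print_token, ignore_stats);
  (void)err;
}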

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 0 additions & 60 deletions
@@ -346,64 +346,4 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
   EXPECT_TRUE(runner.is_loaded());
 }
 
-// Test that generate_from_pos() errors out when max_new_tokens is negative
-TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) {
-  // Create mock instances using helper functions
-  auto tokenizer = createMockTokenizer();
-  auto text_decoder_runner = createMockTextDecoderRunner();
-  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
-
-  // Set up expectations for the tokenizer encode method
-  ON_CALL(*tokenizer, encode(_, _, _))
-      .WillByDefault([&](const std::string&, int8_t, int8_t) {
-        return ::tokenizers::Result<std::vector<uint64_t>>(
-            std::vector<uint64_t>{1, 2, 3});
-      });
-
-  // Set up expectations for load methods
-  ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
-
-  std::unique_ptr<executorch::llm::Stats> stats =
-      std::make_unique<executorch::llm::Stats>();
-  // Create a real TextTokenGenerator
-  auto text_token_generator = createTextTokenGenerator(
-      tokenizer.get(), text_decoder_runner.get(), stats.get());
-
-  // Create a Runner with our mocked components
-  auto module = std::make_unique<MockModule>();
-  auto io_manager =
-      std::make_unique<executorch::extension::llm::IOManager>(*module);
-  TextLLMRunner runner(
-      {
-          {"enable_dynamic_shape", false},
-          {"get_max_seq_len", 10},
-          {"get_max_context_len", 10},
-          {"use_kv_cache", true},
-      },
-      std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
-      std::move(module),
-      std::move(text_decoder_runner),
-      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
-          text_prefiller.release()),
-      std::move(io_manager),
-      std::move(text_token_generator),
-      std::move(stats));
-
-  // Load
-  runner.load();
-
-  // Set up the generation config with a negative max_new_tokens value
-  GenerationConfig config;
-  config.max_new_tokens = 5;
-  config.echo = false;
-
-  // num_prompt_tokens = 3
-  // max_context_len = 10
-  // start_pos = 8, this should fail because 10 - 8 > 3, even though
-  // config.max_new_tokens = 5 > 3, it's still a failure.
-  Error err = runner.generate_from_pos("test prompt", 8, config);
-
-  // Verify that an InvalidArgument error is returned
-  EXPECT_EQ(err, Error::InvalidArgument);
-}
 } // namespace

extension/llm/runner/text_llm_runner.cpp

Lines changed: 14 additions & 8 deletions
@@ -72,7 +72,7 @@ Error TextLLMRunner::load() {
 
 Error TextLLMRunner::generate_from_pos(
     const std::string& prompt,
-    int64_t start_pos,
+    ET_UNUSED int64_t start_pos,
    const GenerationConfig& config,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
@@ -123,8 +123,8 @@ Error TextLLMRunner::generate_from_pos(
   std::vector<uint64_t> prompt_tokens = encode_res.get();
   int num_prompt_tokens = prompt_tokens.size();
 
-  // Reduce max_context_len by start_pos
-  int64_t max_context_len = metadata_.at(kMaxContextLen) - start_pos;
+  // Reduce max_context_len by pos_
+  int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
  ET_CHECK_OR_RETURN_ERROR(
      num_prompt_tokens >= 1,
      InvalidArgument,
@@ -138,16 +138,16 @@ Error TextLLMRunner::generate_from_pos(
       max_context_len);
 
   // Determine max_new_tokens using the GenerationConfig's resolve method,
-  // then subtract start_pos for max_new_tokens.
+  // then subtract pos_ for max_new_tokens.
  int max_new_tokens =
      config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
 
  ET_LOG(
      Info,
-      "Max new tokens resolved: %d, given start_pos %" PRId64
+      "Max new tokens resolved: %d, given pos_ %" PRId64
      ", num_prompt_tokens %zu, max_context_len %" PRId64,
      max_new_tokens,
-      start_pos,
+      pos_,
      prompt_tokens.size(),
      max_context_len);
  ET_CHECK_OR_RETURN_ERROR(
@@ -163,8 +163,7 @@ Error TextLLMRunner::generate_from_pos(
   if (config.echo) {
     wrapped_callback(prompt);
   }
-  int64_t pos = start_pos;
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
  uint64_t cur_token = prefill_res.get();
  stats_->first_token_ms = time_in_ms();
@@ -217,11 +216,13 @@ Error TextLLMRunner::generate_from_pos(
 
   return Error::Ok;
 }
+
 Error TextLLMRunner::generate(
     const std::string& prompt,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
+  pos_ = 0;
  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
 }
 
@@ -246,4 +247,9 @@ void TextLLMRunner::stop() {
   }
 }
 
+void TextLLMRunner::reset() {
+  stats_->reset();
+  pos_ = 0;
+}
+
 } // namespace executorch::extension::llm

extension/llm/runner/text_llm_runner.h

Lines changed: 21 additions & 16 deletions
@@ -102,25 +102,20 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       std::function<void(const Stats&)> stats_callback = {}) override;
 
   /**
-   * @brief Generates text based on the provided prompt and start position
+   * Generate text based on the provided prompt and generation config, from a
+   * given position in KV cache.
    *
-   * This method performs text generation using the loaded model. It processes
-   * the input prompt, runs the model in prefill and decode phases using the
-   * start position until max tokens to generate is reached or eos token is
-   * generated, then returns generated text and perf stats through callbacks.
-   *
-   * @param prompt The input text to generate from
-   * @param start_pos The starting position in KV cache of the input
-   * @param config Configuration parameters for text generation (e.g.,
-   * max_new_tokens, temperature)
-   * @param token_callback Function called for each generated token with the
-   * decoded text
-   * @param stats_callback Function called with performance statistics
-   * @return ::executorch::runtime::Error Success or error status
+   * @param prompt The input prompt to generate from
+   * @param start_pos [Unused] The starting position in KV cache of the input,
+   * ignored because the runner manages the position internally.
+   * @param config Generation configuration parameters
+   * @param token_callback Callback function called for each generated token
+   * @param stats_callback Callback function for generation statistics
+   * @return Error::Ok if successful, an error otherwise
    */
-  ::executorch::runtime::Error generate_from_pos(
+  ET_DEPRECATED runtime::Error generate_from_pos(
      const std::string& prompt,
-      int64_t start_pos,
+      ET_UNUSED int64_t start_pos,
      const GenerationConfig& config,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {}) override;
@@ -138,6 +133,13 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
   ::executorch::runtime::Error warmup(
       const std::string& prompt,
       int32_t max_new_tokens);
+  /**
+   * @brief Remove prefilled tokens and reset start position, and stats.
+   *
+   * This method removes the prefilled tokens from the KV cache and resets the
+   * start position to 0. It also clears the stats for previous runs.
+   */
+  void reset() override;
  /**
   * @brief Stops the ongoing text generation process
   *
@@ -169,6 +171,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
   // temperature.
   // Deprecated, we should rely on the temperature in GenerationConfig instead.
   float temperature_ = -1.0f;
+
+  // The position in KV cache of the input, starting from 0.
+  int64_t pos_ = 0;
 };
 
 } // namespace executorch::extension::llm
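
Taken together, the TextLLMRunner changes amount to the following usage pattern. This is a sketch only: runner construction and load() are elided, the prompts are placeholders, and the callback default arguments and include path are assumed from the declarations above.

#include <executorch/extension/llm/runner/text_llm_runner.h> // assumed include path

using ::executorch::extension::llm::GenerationConfig;
using ::executorch::extension::llm::TextLLMRunner;
using ::executorch::runtime::Error;

// Assumes `runner` was constructed with its components and load() returned Ok.
Error run_session(TextLLMRunner& runner) {
  GenerationConfig config;
  config.max_new_tokens = 16;
  config.echo = false;

  // generate() now rewinds the internal position (pos_ = 0) before prefilling.
  Error err = runner.generate("first prompt", config);
  if (err != Error::Ok) {
    return err;
  }

  // generate_from_pos() is deprecated: the start_pos argument is ignored and
  // the runner keeps using its internally tracked pos_ instead.
  err = runner.generate_from_pos("follow-up prompt", /*start_pos=*/0, config);
  if (err != Error::Ok) {
    return err;
  }

  // reset() drops the prefilled tokens, rewinds pos_ to 0, and clears the
  // stats, so the next generate() starts from a fresh KV cache position.
  runner.reset();
  return Error::Ok;
}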
