
Commit 045a419

Revert "[llm] Add generate_from_pos API to LLM runner (#11570)"
This reverts commit be8ffd1.
1 parent: be8ffd1 · commit: 045a419

File tree (4 files changed: +14 additions, -134 deletions):
- extension/llm/runner/irunner.h
- extension/llm/runner/test/test_text_llm_runner.cpp
- extension/llm/runner/text_llm_runner.cpp
- extension/llm/runner/text_llm_runner.h


extension/llm/runner/irunner.h

Lines changed: 0 additions & 17 deletions
@@ -121,23 +121,6 @@ class ET_EXPERIMENTAL IRunner {
       std::function<void(const std::string&)> token_callback,
       std::function<void(const Stats&)> stats_callback) = 0;
 
-  /**
-   * Generate text based on the provided prompt and generation config, from a
-   * given position in KV cache.
-   *
-   * @param prompt The input prompt to generate from
-   * @param start_pos The starting position in KV cache of the input
-   * @param config Generation configuration parameters
-   * @param token_callback Callback function called for each generated token
-   * @param stats_callback Callback function for generation statistics
-   * @return Error::Ok if successful, an error otherwise
-   */
-  virtual runtime::Error generate_from_pos(
-      const std::string& prompt,
-      int64_t start_pos,
-      const GenerationConfig& config,
-      std::function<void(const std::string&)> token_callback,
-      std::function<void(const Stats&)> stats_callback) = 0;
   /**
    * Stop the generation process.
    */
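With this hunk applied, IRunner is back to a single pure-virtual generation entry point. A rough sketch of what that means for implementers follows; the class name EchoRunner, its toy behavior, and the include path are hypothetical, and the namespace is inferred from the extension/llm types referenced elsewhere in this diff — only generate()'s signature comes from the interface shown above.

#include <executorch/extension/llm/runner/irunner.h>  // include path assumed from the file tree

namespace executorch::extension::llm {

// Hypothetical toy runner used only to illustrate the surviving interface.
class EchoRunner : public IRunner {
 public:
  // Before this revert, implementers also had to override
  // generate_from_pos(prompt, start_pos, config, token_cb, stats_cb);
  // after it, generate() is the only generation entry point.
  runtime::Error generate(
      const std::string& prompt,
      const GenerationConfig& config,
      std::function<void(const std::string&)> token_callback,
      std::function<void(const Stats&)> stats_callback) override {
    (void)config;
    (void)stats_callback;
    if (token_callback) {
      token_callback(prompt);  // toy behavior: echo the prompt back
    }
    return runtime::Error::Ok;
  }
  // The remaining IRunner pure virtuals (load(), is_loaded(), stop(), ...)
  // are omitted here; a real implementation must still provide them.
};

}  // namespace executorch::extension::llm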

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 0 additions & 55 deletions
@@ -322,58 +322,3 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
   // Verify is_loaded returns true
   EXPECT_TRUE(runner.is_loaded());
 }
-
-// Test that generate_from_pos() errors out when max_new_tokens is negative
-TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) {
-  // Create mock instances using helper functions
-  auto tokenizer = createMockTokenizer();
-  auto text_decoder_runner = createMockTextDecoderRunner();
-  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
-
-  // Set up expectations for the tokenizer encode method
-  EXPECT_CALL(*tokenizer, encode(_, _, _))
-      .WillOnce(Return(::tokenizers::Result<std::vector<uint64_t>>(
-          std::vector<uint64_t>{1, 2, 3})));
-
-  // Set up expectations for load methods
-  EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true));
-
-  std::unique_ptr<executorch::llm::Stats> stats =
-      std::make_unique<executorch::llm::Stats>();
-  // Create a real TextTokenGenerator
-  auto text_token_generator = createTextTokenGenerator(
-      tokenizer.get(), text_decoder_runner.get(), stats.get());
-
-  // Create a Runner with our mocked components
-  TextLLMRunner runner(
-      {
-          {"enable_dynamic_shape", false},
-          {"get_max_seq_len", 10},
-          {"get_max_context_len", 10},
-          {"use_kv_cache", true},
-      },
-      std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
-      std::make_unique<MockModule>(),
-      std::move(text_decoder_runner),
-      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
-          text_prefiller.release()),
-      std::move(text_token_generator),
-      std::move(stats));
-
-  // Load
-  runner.load();
-
-  // Set up the generation config with a negative max_new_tokens value
-  GenerationConfig config;
-  config.max_new_tokens = 5;
-  config.echo = false;
-
-  // num_prompt_tokens = 3
-  // max_context_len = 10
-  // start_pos = 8, this should fail because 10 - 8 > 3, even though
-  // config.max_new_tokens = 5 > 3, it's still a failure.
-  Error err = runner.generate_from_pos("test prompt", 8, config);
-
-  // Verify that an InvalidArgument error is returned
-  EXPECT_EQ(err, Error::InvalidArgument);
-}
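The deleted test's expectation comes down to simple arithmetic: with the runner's max context length set to 10 and start_pos = 8, only two KV-cache positions remain for a 3-token prompt, so no positive token budget can be resolved even though config.max_new_tokens is 5. The standalone sketch below restates those numbers only; it is illustrative and does not reproduce the runner's internal checks.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t max_context_len = 10;      // runner metadata in the deleted test
  const int64_t start_pos = 8;             // second argument to generate_from_pos
  const int64_t num_prompt_tokens = 3;     // the mock tokenizer returns {1, 2, 3}
  const int64_t requested_new_tokens = 5;  // config.max_new_tokens in the test

  // generate_from_pos() shrank the usable context by start_pos, so only
  // 10 - 8 = 2 positions remained, fewer than the 3 prompt tokens alone.
  const int64_t remaining = max_context_len - start_pos;
  const int64_t room_for_new_tokens = remaining - num_prompt_tokens;  // -1

  std::printf("remaining=%lld, room for new tokens=%lld (requested %lld)\n",
              static_cast<long long>(remaining),
              static_cast<long long>(room_for_new_tokens),
              static_cast<long long>(requested_new_tokens));
  return 0;  // the deleted test expected Error::InvalidArgument at this point
}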

extension/llm/runner/text_llm_runner.cpp

Lines changed: 12 additions & 34 deletions
@@ -73,9 +73,8 @@ Error TextLLMRunner::load() {
     ET_LOG(Info, format, __VA_ARGS__); \
   }
 
-Error TextLLMRunner::generate_from_pos(
+Error TextLLMRunner::generate(
     const std::string& prompt,
-    int64_t start_pos,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
@@ -126,34 +125,20 @@ Error TextLLMRunner::generate_from_pos(
   std::vector<uint64_t> prompt_tokens = encode_res.get();
   int num_prompt_tokens = prompt_tokens.size();
 
-  // Reduce max_context_len by start_pos
-  int64_t max_context_len = metadata_.at(kMaxContextLen) - start_pos;
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
-      num_prompt_tokens < max_context_len,
-      "num_prompt_tokens %d >= max_context_len %" PRId64
+      num_prompt_tokens < metadata_.at(kMaxContextLen),
+      "num_prompt_tokens %d >= max_seq_len_ %" PRId64
       ", Max seq length exceeded - please increase max seq len value in your export script",
       num_prompt_tokens,
-      max_context_len);
-
-  // Determine max_new_tokens using the GenerationConfig's resolve method,
-  // then subtract start_pos for max_new_tokens.
-  int max_new_tokens =
-      config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
-
-  ET_LOG(
-      Info,
-      "Max new tokens resolved: %d, given start_pos %" PRId64
-      ", num_prompt_tokens %zu, max_context_len %" PRId64,
-      max_new_tokens,
-      start_pos,
-      prompt_tokens.size(),
-      max_context_len);
-
-  ET_CHECK_OR_RETURN_ERROR(
-      max_new_tokens > 0,
-      InvalidArgument,
-      "Max new tokens %d is less than or equal to 0",
-      max_new_tokens);
+      metadata_.at(kMaxContextLen));
+
+  // Determine max_new_tokens using the GenerationConfig's resolve method
+  int max_new_tokens = config.resolve_max_new_tokens(
+      metadata_.at(kMaxContextLen), num_prompt_tokens);
+
+  ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens);
+
   // Prefill first
   // Here feed all tokens to the model and get the next predicted token
   // after the prompt. After that we will enter generate loop.
@@ -162,7 +147,7 @@ Error TextLLMRunner::generate_from_pos(
   if (config.echo) {
     wrapped_callback(prompt);
   }
-  int64_t pos = start_pos;
+  int64_t pos = 0;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
@@ -216,13 +201,6 @@ Error TextLLMRunner::generate_from_pos(
 
   return Error::Ok;
 }
-Error TextLLMRunner::generate(
-    const std::string& prompt,
-    const GenerationConfig& config,
-    std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
-  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
-}
 
 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   // Create a GenerationConfig for warmup
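The substantive change in this file is how the new-token budget is computed: the reverted code shrank the usable context by start_pos before calling config.resolve_max_new_tokens(), while the restored code always resolves against the full kMaxContextLen and prefills from position 0. The sketch below contrasts the two computations under the assumption that resolve_max_new_tokens() simply caps the requested budget by whatever context remains after the prompt; that is an illustration-only simplification, and all values are made up.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed simplification of GenerationConfig::resolve_max_new_tokens():
// cap the requested budget by the context left over after the prompt.
static int64_t resolve_max_new_tokens(
    int64_t context_len, int64_t num_prompt_tokens, int64_t requested) {
  return std::min(requested, context_len - num_prompt_tokens);
}

int main() {
  const int64_t kMaxContextLen = 2048;  // illustrative values throughout
  const int64_t num_prompt_tokens = 16;
  const int64_t requested = 256;
  const int64_t start_pos = 1024;  // only meaningful before the revert

  // Reverted behavior: the budget shrinks as start_pos grows, and the call
  // fails with InvalidArgument once it reaches zero or below.
  const int64_t before = resolve_max_new_tokens(
      kMaxContextLen - start_pos, num_prompt_tokens, requested);

  // Restored behavior: the budget is resolved against the full context and
  // prefill always begins at KV-cache position 0.
  const int64_t after = resolve_max_new_tokens(
      kMaxContextLen, num_prompt_tokens, requested);

  std::printf("before revert: %lld, after revert: %lld\n",
              static_cast<long long>(before),
              static_cast<long long>(after));
  return 0;
}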

extension/llm/runner/text_llm_runner.h

Lines changed: 2 additions & 28 deletions
@@ -78,9 +78,8 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
    * @brief Generates text based on the provided prompt
    *
    * This method performs text generation using the loaded model. It processes
-   * the input prompt, runs the model in prefill and decode phases until max
-   * tokens to generate is reached or eos token is generated, then returns
-   * generated text and perf stats through callbacks.
+   * the input prompt, runs the model in prefill and decode phases, and returns
+   * generated text through callbacks.
    *
    * @param prompt The input text to generate from
    * @param config Configuration parameters for text generation (e.g.,
@@ -95,31 +94,6 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       const GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {}) override;
-
-  /**
-   * @brief Generates text based on the provided prompt and start position
-   *
-   * This method performs text generation using the loaded model. It processes
-   * the input prompt, runs the model in prefill and decode phases using the
-   * start position until max tokens to generate is reached or eos token is
-   * generated, then returns generated text and perf stats through callbacks.
-   *
-   * @param prompt The input text to generate from
-   * @param start_pos The starting position in KV cache of the input
-   * @param config Configuration parameters for text generation (e.g.,
-   *     max_new_tokens, temperature)
-   * @param token_callback Function called for each generated token with the
-   *     decoded text
-   * @param stats_callback Function called with performance statistics
-   * @return ::executorch::runtime::Error Success or error status
-   */
-  ::executorch::runtime::Error generate_from_pos(
-      const std::string& prompt,
-      int64_t start_pos,
-      const GenerationConfig& config,
-      std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {}) override;
-
   /**
    * @brief Warms up the model with a sample prompt
    *
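On the caller side, the generation API that survives this revert is just generate() (plus warmup()). Below is a hypothetical usage sketch reusing only names that appear in these diffs; the include path, namespace qualifications, prompt, and config values are illustrative assumptions rather than documented usage.

#include <executorch/extension/llm/runner/text_llm_runner.h>  // path assumed from the file tree

#include <iostream>
#include <string>

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::TextLLMRunner;
using executorch::llm::Stats;
using executorch::runtime::Error;

// `runner` is assumed to be constructed and loaded elsewhere; the deleted
// test above shows the constructor shape with mocked components.
Error run_single_turn(TextLLMRunner& runner) {
  GenerationConfig config;
  config.max_new_tokens = 64;  // illustrative value
  config.echo = true;          // stream the prompt back before new tokens

  return runner.generate(
      "Tell me a story.",
      config,
      [](const std::string& piece) { std::cout << piece; },   // per-token text
      [](const Stats& /*stats*/) { /* inspect perf stats if desired */ });
}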
