Commit e6736a6

Fix text prefiller start pos being updated twice (#15531)
In `text_prefiller.cpp`, `prefill()` updates `start_pos` both inside `prefillChunk()` and again after the call returns, advancing the position twice per chunk. This PR removes the redundant outer update so the position is advanced exactly once, inside `prefillChunk()`. It also updates the llama `main.cpp` to take a `max_new_tokens` flag. Co-authored-by: Mengwei Liu <[email protected]>
1 parent 69dd7b9 commit e6736a6
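
To make the bug concrete, here is a minimal, self-contained C++ sketch of the pattern being fixed. The names (prefill_chunk, pos) are simplified stand-ins for the ExecuTorch code, not its actual API: the chunk call advances the position through its reference parameter, so a caller that also adds the chunk size moves the position twice.

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for TextPrefiller::prefill_chunk: processes `tokens` starting at
// `pos` and advances `pos` by the number of tokens consumed.
void prefill_chunk(const std::vector<uint64_t>& tokens, int64_t& pos) {
  pos += static_cast<int64_t>(tokens.size());
}

int main() {
  std::vector<uint64_t> chunk = {1, 2, 3};

  // Buggy caller: the callee already advanced pos by 3 and the caller adds 3
  // again, so the next chunk would start at KV-cache position 6 instead of 3.
  int64_t pos = 0;
  prefill_chunk(chunk, pos);
  pos += static_cast<int64_t>(chunk.size()); // the redundant outer update
  std::printf("buggy pos = %lld\n", static_cast<long long>(pos)); // prints 6

  // Fixed caller (what this commit does): the callee owns the update.
  pos = 0;
  prefill_chunk(chunk, pos);
  std::printf("fixed pos = %lld\n", static_cast<long long>(pos)); // prints 3
  return 0;
}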

4 files changed, +187 -100 lines changed

examples/models/llama/main.cpp

Lines changed: 24 additions & 6 deletions
@@ -35,7 +35,12 @@ DEFINE_double(
 DEFINE_int32(
     seq_len,
     128,
-    "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+    "DEPRECATED: Please use max_seq_len instead. Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+
+DEFINE_int32(
+    max_new_tokens,
+    -1,
+    "Total number of tokens to generate, excluding the prompt, will be capped by max_seq_len - # prompt tokens.");
 
 DEFINE_int32(
     cpu_threads,
@@ -100,20 +105,33 @@ int32_t main(int32_t argc, char** argv) {
   }
 
   if (warmup) {
-    auto error = runner->warmup(prompt, /*max_new_tokens=*/seq_len);
+    int32_t warmup_max_new_tokens =
+        FLAGS_max_new_tokens != -1 ? FLAGS_max_new_tokens : seq_len;
+    auto error =
+        runner->warmup(prompt, /*max_new_tokens=*/warmup_max_new_tokens);
     if (error != executorch::runtime::Error::Ok) {
       ET_LOG(Error, "Failed to warmup llama runner");
       return 1;
     }
-    // reset kv cache pos to 0
-    runner->reset();
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
-      .seq_len = seq_len, .temperature = temperature};
+      .temperature = temperature};
+
+  if (FLAGS_max_new_tokens != -1) {
+    config.max_new_tokens = FLAGS_max_new_tokens;
+  } else {
+    ET_LOG(
+        Info,
+        "max_new_tokens not provided, falling back to seq_len=%d. "
+        "Consider using --max_new_tokens instead of --seq_len for specifying generation length.",
+        seq_len);
+    config.seq_len = seq_len;
+  }
+
   auto error = runner->generate(prompt, config);
   if (error != executorch::runtime::Error::Ok) {
-    ET_LOG(Error, "Failed to warmup llama runner");
+    ET_LOG(Error, "Failed to run llama runner");
     return 1;
   }
 
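Based only on the two flag descriptions above, the intended relationship between the flags can be sketched as follows. resolve_max_new_tokens is a hypothetical helper written for illustration, not part of GenerationConfig or the runner API:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: how many new tokens to generate, per the flag docs.
int32_t resolve_max_new_tokens(
    int32_t max_seq_len, // model context limit
    int32_t num_prompt_tokens,
    int32_t max_new_tokens, // -1 when unset
    int32_t seq_len) { // legacy: total budget for prompt + output
  if (max_new_tokens != -1) {
    // The new flag excludes the prompt, capped by the remaining context.
    return std::min(max_new_tokens, max_seq_len - num_prompt_tokens);
  }
  // The legacy flag counts the prompt against the budget.
  return std::min(seq_len, max_seq_len) - num_prompt_tokens;
}

int main() {
  // A 100-token prompt in a 2048-token context:
  std::printf("%d\n", resolve_max_new_tokens(2048, 100, 64, 128)); // 64
  std::printf("%d\n", resolve_max_new_tokens(2048, 100, -1, 128)); // 28
  return 0;
}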

extension/llm/runner/test/test_text_prefiller.cpp

Lines changed: 162 additions & 92 deletions
@@ -138,113 +138,68 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
 TEST_F(
     TextPrefillerTest,
     PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
-  // Create a spy TextPrefiller with max_seq_len = 3
+  // Create a real TextPrefiller with max_seq_len = 3 and parallel prefill
   const int64_t max_seq_len = 3;
-  auto prefiller = createMockTextPrefiller(max_seq_len);
+  auto prefiller = createTextPrefiller(max_seq_len, true, true);
 
   // Create prompt tokens with size > max_seq_len
   std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
   int64_t start_pos = 0;
 
-  // Set up expectations for prefill_chunk calls
-  {
-    InSequence seq; // Ensure calls happen in the expected order
-
-    // First chunk: tokens [1, 2, 3]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 3);
-          EXPECT_EQ(tokens[0], 1);
-          EXPECT_EQ(tokens[1], 2);
-          EXPECT_EQ(tokens[2], 3);
-          EXPECT_EQ(pos, 0);
-          return Result<uint64_t>(10);
-        });
-
-    // Second chunk: tokens [4, 5, 6]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 3);
-          EXPECT_EQ(tokens[0], 4);
-          EXPECT_EQ(tokens[1], 5);
-          EXPECT_EQ(tokens[2], 6);
-          EXPECT_EQ(pos, 3);
-          return Result<uint64_t>(20);
-        });
-
-    // Third chunk: tokens [7, 8]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 2);
-          EXPECT_EQ(tokens[0], 7);
-          EXPECT_EQ(tokens[1], 8);
-          EXPECT_EQ(pos, 6);
-          return Result<uint64_t>(30);
-        });
-  }
+  // Track all tokens and positions passed to text_decoder_runner step
+  struct StepCall {
+    std::vector<uint64_t> tokens;
+    int64_t pos;
+  };
+  std::vector<StepCall> step_calls;
+
+  // Set up expectations for text_decoder_runner step calls
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3) // Should be called 3 times for 3 chunks
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            // Extract token values from tensor
+            std::vector<uint64_t> token_values;
+            int64_t num_tokens = tokens->size(1);
+            auto* token_data = tokens->const_data_ptr<int64_t>();
+            for (int64_t i = 0; i < num_tokens; i++) {
+              token_values.push_back(static_cast<uint64_t>(token_data[i]));
+            }
+            step_calls.push_back({token_values, pos});
+            return Result<executorch::aten::Tensor>(tensor);
+          });
 
   // Call prefill
   auto result = prefiller->prefill(prompt_tokens, start_pos);
 
   // Verify the result
   EXPECT_EQ(result.error(), Error::Ok);
-  EXPECT_EQ(result.get(), 30); // Should return the token from the last chunk
-
-  // Verify that start_pos has been updated correctly
-  EXPECT_EQ(start_pos, prompt_tokens.size());
-}
-
-// Test that prefill() handles edge cases correctly
-TEST_F(TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
-  // Create a spy TextPrefiller with max_seq_len = 1
-  const int64_t max_seq_len = 1;
-  auto prefiller = createMockTextPrefiller(max_seq_len);
-
-  // Create prompt tokens with size > max_seq_len
-  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
-  int64_t start_pos = 5; // Non-zero starting position
-
-  // Set up expectations for prefill_chunk calls
-  {
-    InSequence seq;
-
-    // First chunk: token [1]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 1);
-          EXPECT_EQ(pos, 5);
-          return Result<uint64_t>(10);
-        });
-
-    // Second chunk: token [2]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 2);
-          EXPECT_EQ(pos, 6);
-          return Result<uint64_t>(20);
-        });
-
-    // Third chunk: token [3]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 3);
-          EXPECT_EQ(pos, 7);
-          return Result<uint64_t>(30);
-        });
-  }
-
-  // Call prefill
-  auto result = prefiller->prefill(prompt_tokens, start_pos);
 
-  // Verify the result
-  EXPECT_EQ(result.error(), Error::Ok);
-  EXPECT_EQ(result.get(), 30);
+  // Verify that step was called 3 times with correct tokens and positions
+  ASSERT_EQ(step_calls.size(), 3);
+
+  // First chunk: tokens [1, 2, 3] at position 0
+  EXPECT_EQ(step_calls[0].tokens.size(), 3);
+  EXPECT_EQ(step_calls[0].tokens[0], 1);
+  EXPECT_EQ(step_calls[0].tokens[1], 2);
+  EXPECT_EQ(step_calls[0].tokens[2], 3);
+  EXPECT_EQ(step_calls[0].pos, 0);
+
+  // Second chunk: tokens [4, 5, 6] at position 3
+  EXPECT_EQ(step_calls[1].tokens.size(), 3);
+  EXPECT_EQ(step_calls[1].tokens[0], 4);
+  EXPECT_EQ(step_calls[1].tokens[1], 5);
+  EXPECT_EQ(step_calls[1].tokens[2], 6);
+  EXPECT_EQ(step_calls[1].pos, 3);
+
+  // Third chunk: tokens [7, 8] at position 6
+  EXPECT_EQ(step_calls[2].tokens.size(), 2);
+  EXPECT_EQ(step_calls[2].tokens[0], 7);
+  EXPECT_EQ(step_calls[2].tokens[1], 8);
+  EXPECT_EQ(step_calls[2].pos, 6);
 
   // Verify that start_pos has been updated correctly
-  EXPECT_EQ(start_pos, 8); // 5 (initial) + 3 (tokens)
+  EXPECT_EQ(start_pos, prompt_tokens.size());
 }
 
 // Test that prefill() handles errors from prefill_chunk correctly
@@ -305,4 +260,119 @@ TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
   // Verify that start_pos has been updated correctly
   EXPECT_EQ(start_pos, prompt_tokens.size());
 }
+// Test that prefill_chunk updates start_pos correctly with parallel prefill
+TEST_F(TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlyParallel) {
+  // Create a TextPrefiller with parallel prefill enabled
+  auto prefiller = createTextPrefiller(10, true, true);
+
+  // Set up expectations for the text decoder runner
+  int64_t captured_pos = -1;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .WillOnce([&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+        captured_pos = pos;
+        // Verify tokens shape is [1, num_tokens]
+        EXPECT_EQ(tokens->dim(), 2);
+        EXPECT_EQ(tokens->size(0), 1);
+        EXPECT_EQ(tokens->size(1), 3);
+        return Result<executorch::aten::Tensor>(tensor);
+      });
+
+  // Create prompt tokens
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
+  int64_t start_pos = 5; // Non-zero starting position
+
+  // Call prefill_chunk directly
+  auto result = prefiller->prefill_chunk(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with the original start_pos
+  EXPECT_EQ(captured_pos, 5);
+
+  // Verify that start_pos has been updated by the number of tokens
+  // This is the key test: start_pos should be updated exactly once
+  EXPECT_EQ(start_pos, 8); // 5 + 3 tokens
+}
+
+// Test that prefill_chunk updates start_pos correctly with sequential prefill
+TEST_F(TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlySequential) {
+  // Create a TextPrefiller with sequential prefill (parallel disabled)
+  auto prefiller = createTextPrefiller(10, true, false);
+
+  // Track all positions passed to step
+  std::vector<int64_t> captured_positions;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3)
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            captured_positions.push_back(pos);
+            // Verify tokens shape is [1, 1] for sequential prefill
+            EXPECT_EQ(tokens->dim(), 2);
+            EXPECT_EQ(tokens->size(0), 1);
+            EXPECT_EQ(tokens->size(1), 1);
+            return Result<executorch::aten::Tensor>(tensor);
+          });
+
+  // Create prompt tokens
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
+  int64_t start_pos = 10; // Non-zero starting position
+
+  // Call prefill_chunk directly
+  auto result = prefiller->prefill_chunk(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with incrementing positions
+  ASSERT_EQ(captured_positions.size(), 3);
+  EXPECT_EQ(captured_positions[0], 10); // First token at initial start_pos
+  EXPECT_EQ(captured_positions[1], 11); // Second token at start_pos + 1
+  EXPECT_EQ(captured_positions[2], 12); // Third token at start_pos + 2
+
+  // Verify that start_pos has been updated by the number of tokens
+  // This is the key test: start_pos should be updated exactly once per token
+  EXPECT_EQ(start_pos, 13); // 10 + 3 tokens
+}
+
+// Test that prefill with chunking updates start_pos correctly across chunks.
+// This test would have caught the bug where start_pos was being updated twice.
+TEST_F(
+    TextPrefillerTest,
+    PrefillWithChunkingUpdatesStartPosCorrectlyAcrossChunks) {
+  // Create a TextPrefiller with max_seq_len = 3 and parallel prefill
+  auto prefiller = createTextPrefiller(3, true, true);
+
+  // Track all positions passed to step
+  std::vector<int64_t> captured_positions;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3) // Should be called 3 times: [1,2,3], [4,5,6], [7,8]
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            captured_positions.push_back(pos);
+            return Result<executorch::aten::Tensor>(tensor);
+          });
+
+  // Create prompt tokens that exceed max_seq_len
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
+  int64_t start_pos = 100; // Non-zero starting position
+
+  // Call prefill (which will chunk internally)
+  auto result = prefiller->prefill(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with correct positions for each chunk
+  // If start_pos were updated twice (the bug), these would be wrong
+  ASSERT_EQ(captured_positions.size(), 3);
+  EXPECT_EQ(captured_positions[0], 100); // Chunk 1: tokens [1,2,3]
+  EXPECT_EQ(captured_positions[1], 103); // Chunk 2: tokens [4,5,6]
+  EXPECT_EQ(captured_positions[2], 106); // Chunk 3: tokens [7,8]
+
+  // Verify that final start_pos is correct
+  // This is the key test for the bug: start_pos should be exactly
+  // initial_pos + num_tokens, not double-incremented
+  EXPECT_EQ(start_pos, 108); // 100 + 8 tokens
+}
 } // namespace
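
Note the shift in test strategy: the old tests mocked prefill_chunk itself, and since the mock lambdas never advanced pos, the outer increment alone produced the expected 0/3/6 position sequence, which is why the double update went undetected. The rewritten tests drive the real prefill_chunk against a mocked TextDecoderRunner::step, so they observe the positions actually handed to the decoder. They can be run selectively with the standard GoogleTest filter, e.g. --gtest_filter=TextPrefillerTest.*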

extension/llm/runner/text_llm_runner.cpp

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   Error err = generate(prompt, config);
 
   // Reset stats after warmup, not resetting the std::unique_ptr!
-  stats_->reset();
+  reset();
   return err;
 }
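
This pairs with the main.cpp change above: warmup() now calls the runner-level reset() rather than only stats_->reset(), which presumably clears the KV-cache position along with the stats, making the explicit "runner->reset()" after warmup in main.cpp redundant and allowing its removal.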

extension/llm/runner/text_prefiller.cpp

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
     cur_token = chunk_result.get();
 
-    start_pos += num_tokens_to_prefill_with;
     num_tokens_to_process += num_tokens_to_prefill_with;
   }
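
With the redundant line removed, the position is advanced exactly once, inside prefillChunk(). Below is a simplified sketch of the corrected loop; the helper names and return type are stand-ins for the real TextPrefiller code, which returns a Result<uint64_t> and checks errors per chunk:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for prefill_chunk: feeds `tokens` to the decoder at `pos`,
// advances `pos` by tokens.size(), and returns the sampled next token.
uint64_t prefill_chunk_sketch(std::vector<uint64_t>& tokens, int64_t& pos) {
  pos += static_cast<int64_t>(tokens.size());
  return tokens.back(); // placeholder for the sampled token
}

uint64_t prefill_sketch(
    std::vector<uint64_t>& prompt_tokens,
    int64_t& start_pos,
    int64_t max_seq_len) {
  const int64_t num_tokens = static_cast<int64_t>(prompt_tokens.size());
  int64_t num_tokens_to_process = 0;
  uint64_t cur_token = 0;
  while (num_tokens_to_process < num_tokens) {
    const int64_t chunk_len =
        std::min(max_seq_len, num_tokens - num_tokens_to_process);
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + num_tokens_to_process,
        prompt_tokens.begin() + num_tokens_to_process + chunk_len);
    // start_pos is advanced inside the chunk call and nowhere else (the fix).
    cur_token = prefill_chunk_sketch(chunk, start_pos);
    num_tokens_to_process += chunk_len;
  }
  return cur_token;
}

int main() {
  std::vector<uint64_t> prompt = {1, 2, 3, 4, 5, 6, 7, 8};
  int64_t start_pos = 100;
  prefill_sketch(prompt, start_pos, /*max_seq_len=*/3);
  assert(start_pos == 108); // 100 + 8 tokens, not double-incremented
  return 0;
}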
