Commit a11d555

Fix text prefiller start pos being updated twice
Differential Revision: D86013259
Pull Request resolved: #15509
1 parent cc72b35 commit a11d555

File tree: 4 files changed, +187 / -100 lines

examples/models/llama/main.cpp

Lines changed: 24 additions & 6 deletions
@@ -40,7 +40,12 @@ DEFINE_double(
 DEFINE_int32(
     seq_len,
     128,
-    "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+    "DEPRECATED: Please use max_seq_len instead. Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+
+DEFINE_int32(
+    max_new_tokens,
+    -1,
+    "Total number of tokens to generate, excluding the prompt, will be capped by max_seq_len - # prompt tokens.");
 
 DEFINE_int32(
     cpu_threads,
@@ -122,20 +127,33 @@ int32_t main(int32_t argc, char** argv) {
   }
 
   if (warmup) {
-    auto error = runner->warmup(prompt, /*max_new_tokens=*/seq_len);
+    int32_t warmup_max_new_tokens =
+        FLAGS_max_new_tokens != -1 ? FLAGS_max_new_tokens : seq_len;
+    auto error =
+        runner->warmup(prompt, /*max_new_tokens=*/warmup_max_new_tokens);
     if (error != executorch::runtime::Error::Ok) {
       ET_LOG(Error, "Failed to warmup llama runner");
       return 1;
     }
-    // reset kv cache pos to 0
-    runner->reset();
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
-      .seq_len = seq_len, .temperature = temperature};
+      .temperature = temperature};
+
+  if (FLAGS_max_new_tokens != -1) {
+    config.max_new_tokens = FLAGS_max_new_tokens;
+  } else {
+    ET_LOG(
+        Info,
+        "max_new_tokens not provided, falling back to seq_len=%d. "
+        "Consider using --max_new_tokens instead of --seq_len for specifying generation length.",
+        seq_len);
+    config.seq_len = seq_len;
+  }
+
   auto error = runner->generate(prompt, config);
   if (error != executorch::runtime::Error::Ok) {
-    ET_LOG(Error, "Failed to warmup llama runner");
+    ET_LOG(Error, "Failed to run llama runner");
     return 1;
   }
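
With this change, output length for the llama example is set with --max_new_tokens, and --seq_len survives only as a deprecated fallback. A hypothetical invocation is sketched below; the --model_path, --tokenizer_path, and --prompt flag names are assumed from the rest of the example and are not part of this diff:

./llama_main \
    --model_path llama.pte \
    --tokenizer_path tokenizer.model \
    --prompt "Once upon a time" \
    --max_new_tokens 64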

extension/llm/runner/test/test_text_prefiller.cpp

Lines changed: 162 additions & 92 deletions
@@ -138,113 +138,68 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
 TEST_F(
     TextPrefillerTest,
     PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
-  // Create a spy TextPrefiller with max_seq_len = 3
+  // Create a real TextPrefiller with max_seq_len = 3 and parallel prefill
   const int64_t max_seq_len = 3;
-  auto prefiller = createMockTextPrefiller(max_seq_len);
+  auto prefiller = createTextPrefiller(max_seq_len, true, true);
 
   // Create prompt tokens with size > max_seq_len
   std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
   int64_t start_pos = 0;
 
-  // Set up expectations for prefill_chunk calls
-  {
-    InSequence seq; // Ensure calls happen in the expected order
-
-    // First chunk: tokens [1, 2, 3]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 3);
-          EXPECT_EQ(tokens[0], 1);
-          EXPECT_EQ(tokens[1], 2);
-          EXPECT_EQ(tokens[2], 3);
-          EXPECT_EQ(pos, 0);
-          return Result<uint64_t>(10);
-        });
-
-    // Second chunk: tokens [4, 5, 6]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 3);
-          EXPECT_EQ(tokens[0], 4);
-          EXPECT_EQ(tokens[1], 5);
-          EXPECT_EQ(tokens[2], 6);
-          EXPECT_EQ(pos, 3);
-          return Result<uint64_t>(20);
-        });
-
-    // Third chunk: tokens [7, 8]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 2);
-          EXPECT_EQ(tokens[0], 7);
-          EXPECT_EQ(tokens[1], 8);
-          EXPECT_EQ(pos, 6);
-          return Result<uint64_t>(30);
-        });
-  }
+  // Track all tokens and positions passed to text_decoder_runner step
+  struct StepCall {
+    std::vector<uint64_t> tokens;
+    int64_t pos;
+  };
+  std::vector<StepCall> step_calls;
+
+  // Set up expectations for text_decoder_runner step calls
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3) // Should be called 3 times for 3 chunks
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            // Extract token values from tensor
+            std::vector<uint64_t> token_values;
+            int64_t num_tokens = tokens->size(1);
+            auto* token_data = tokens->const_data_ptr<int64_t>();
+            for (int64_t i = 0; i < num_tokens; i++) {
+              token_values.push_back(static_cast<uint64_t>(token_data[i]));
+            }
+            step_calls.push_back({token_values, pos});
+            return Result<executorch::aten::Tensor>(tensor);
+          });
 
   // Call prefill
   auto result = prefiller->prefill(prompt_tokens, start_pos);
 
   // Verify the result
   EXPECT_EQ(result.error(), Error::Ok);
-  EXPECT_EQ(result.get(), 30); // Should return the token from the last chunk
-
-  // Verify that start_pos has been updated correctly
-  EXPECT_EQ(start_pos, prompt_tokens.size());
-}
-
-// Test that prefill() handles edge cases correctly
-TEST_F(TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
-  // Create a spy TextPrefiller with max_seq_len = 1
-  const int64_t max_seq_len = 1;
-  auto prefiller = createMockTextPrefiller(max_seq_len);
-
-  // Create prompt tokens with size > max_seq_len
-  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
-  int64_t start_pos = 5; // Non-zero starting position
-
-  // Set up expectations for prefill_chunk calls
-  {
-    InSequence seq;
-
-    // First chunk: token [1]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 1);
-          EXPECT_EQ(pos, 5);
-          return Result<uint64_t>(10);
-        });
-
-    // Second chunk: token [2]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 2);
-          EXPECT_EQ(pos, 6);
-          return Result<uint64_t>(20);
-        });
-
-    // Third chunk: token [3]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 3);
-          EXPECT_EQ(pos, 7);
-          return Result<uint64_t>(30);
-        });
-  }
-
-  // Call prefill
-  auto result = prefiller->prefill(prompt_tokens, start_pos);
 
-  // Verify the result
-  EXPECT_EQ(result.error(), Error::Ok);
-  EXPECT_EQ(result.get(), 30);
+  // Verify that step was called 3 times with correct tokens and positions
+  ASSERT_EQ(step_calls.size(), 3);
+
+  // First chunk: tokens [1, 2, 3] at position 0
+  EXPECT_EQ(step_calls[0].tokens.size(), 3);
+  EXPECT_EQ(step_calls[0].tokens[0], 1);
+  EXPECT_EQ(step_calls[0].tokens[1], 2);
+  EXPECT_EQ(step_calls[0].tokens[2], 3);
+  EXPECT_EQ(step_calls[0].pos, 0);
+
+  // Second chunk: tokens [4, 5, 6] at position 3
+  EXPECT_EQ(step_calls[1].tokens.size(), 3);
+  EXPECT_EQ(step_calls[1].tokens[0], 4);
+  EXPECT_EQ(step_calls[1].tokens[1], 5);
+  EXPECT_EQ(step_calls[1].tokens[2], 6);
+  EXPECT_EQ(step_calls[1].pos, 3);
+
+  // Third chunk: tokens [7, 8] at position 6
+  EXPECT_EQ(step_calls[2].tokens.size(), 2);
+  EXPECT_EQ(step_calls[2].tokens[0], 7);
+  EXPECT_EQ(step_calls[2].tokens[1], 8);
+  EXPECT_EQ(step_calls[2].pos, 6);
 
   // Verify that start_pos has been updated correctly
-  EXPECT_EQ(start_pos, 8); // 5 (initial) + 3 (tokens)
+  EXPECT_EQ(start_pos, prompt_tokens.size());
 }
 
 // Test that prefill() handles errors from prefill_chunk correctly
@@ -305,4 +260,119 @@ TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
   // Verify that start_pos has been updated correctly
   EXPECT_EQ(start_pos, prompt_tokens.size());
 }
+// Test that prefill_chunk updates start_pos correctly with parallel prefill
+TEST_F(TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlyParallel) {
+  // Create a TextPrefiller with parallel prefill enabled
+  auto prefiller = createTextPrefiller(10, true, true);
+
+  // Set up expectations for the text decoder runner
+  int64_t captured_pos = -1;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .WillOnce([&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+        captured_pos = pos;
+        // Verify tokens shape is [1, num_tokens]
+        EXPECT_EQ(tokens->dim(), 2);
+        EXPECT_EQ(tokens->size(0), 1);
+        EXPECT_EQ(tokens->size(1), 3);
+        return Result<executorch::aten::Tensor>(tensor);
+      });
+
+  // Create prompt tokens
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
+  int64_t start_pos = 5; // Non-zero starting position
+
+  // Call prefill_chunk directly
+  auto result = prefiller->prefill_chunk(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with the original start_pos
+  EXPECT_EQ(captured_pos, 5);
+
+  // Verify that start_pos has been updated by the number of tokens
+  // This is the key test: start_pos should be updated exactly once
+  EXPECT_EQ(start_pos, 8); // 5 + 3 tokens
+}
+
+// Test that prefill_chunk updates start_pos correctly with sequential prefill
+TEST_F(TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlySequential) {
+  // Create a TextPrefiller with sequential prefill (parallel disabled)
+  auto prefiller = createTextPrefiller(10, true, false);
+
+  // Track all positions passed to step
+  std::vector<int64_t> captured_positions;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3)
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            captured_positions.push_back(pos);
+            // Verify tokens shape is [1, 1] for sequential prefill
+            EXPECT_EQ(tokens->dim(), 2);
+            EXPECT_EQ(tokens->size(0), 1);
+            EXPECT_EQ(tokens->size(1), 1);
+            return Result<executorch::aten::Tensor>(tensor);
+          });
+
+  // Create prompt tokens
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
+  int64_t start_pos = 10; // Non-zero starting position
+
+  // Call prefill_chunk directly
+  auto result = prefiller->prefill_chunk(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with incrementing positions
+  ASSERT_EQ(captured_positions.size(), 3);
+  EXPECT_EQ(captured_positions[0], 10); // First token at initial start_pos
+  EXPECT_EQ(captured_positions[1], 11); // Second token at start_pos + 1
+  EXPECT_EQ(captured_positions[2], 12); // Third token at start_pos + 2
+
+  // Verify that start_pos has been updated by the number of tokens
+  // This is the key test: start_pos should be updated exactly once per token
+  EXPECT_EQ(start_pos, 13); // 10 + 3 tokens
+}
+
+// Test that prefill with chunking updates start_pos correctly across chunks.
+// This test would have caught the bug where start_pos was being updated twice.
+TEST_F(
+    TextPrefillerTest,
+    PrefillWithChunkingUpdatesStartPosCorrectlyAcrossChunks) {
+  // Create a TextPrefiller with max_seq_len = 3 and parallel prefill
+  auto prefiller = createTextPrefiller(3, true, true);
+
+  // Track all positions passed to step
+  std::vector<int64_t> captured_positions;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3) // Should be called 3 times: [1,2,3], [4,5,6], [7,8]
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            captured_positions.push_back(pos);
+            return Result<executorch::aten::Tensor>(tensor);
+          });
+
+  // Create prompt tokens that exceed max_seq_len
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
+  int64_t start_pos = 100; // Non-zero starting position
+
+  // Call prefill (which will chunk internally)
+  auto result = prefiller->prefill(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with correct positions for each chunk
+  // If start_pos were updated twice (the bug), these would be wrong
+  ASSERT_EQ(captured_positions.size(), 3);
+  EXPECT_EQ(captured_positions[0], 100); // Chunk 1: tokens [1,2,3]
+  EXPECT_EQ(captured_positions[1], 103); // Chunk 2: tokens [4,5,6]
+  EXPECT_EQ(captured_positions[2], 106); // Chunk 3: tokens [7,8]
+
+  // Verify that final start_pos is correct
+  // This is the key test for the bug: start_pos should be exactly
+  // initial_pos + num_tokens, not double-incremented
+  EXPECT_EQ(start_pos, 108); // 100 + 8 tokens
+}
 } // namespace
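
These tests exercise the real TextPrefiller against the fixture's mocked text_decoder_runner_, instead of spying on prefill_chunk(), so a double update of start_pos now shows up as wrong positions passed to step(). The createTextPrefiller() helper is defined outside this diff; a rough sketch of its assumed shape follows, and both the parameter meanings and the TextPrefiller constructor order are guesses rather than facts taken from this commit:

// Assumed fixture helper; only call sites such as createTextPrefiller(3, true, true)
// appear in this diff, so the names and ordering here are illustrative.
std::unique_ptr<TextPrefiller> createTextPrefiller(
    int64_t max_seq_len,
    bool use_kv_cache = true,
    bool enable_parallel_prefill = false) {
  return std::make_unique<TextPrefiller>(
      &text_decoder_runner_, use_kv_cache, enable_parallel_prefill, max_seq_len);
}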

extension/llm/runner/text_llm_runner.cpp

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   Error err = generate(prompt, config);
 
   // Reset stats after warmup, not resetting the std::unique_ptr!
-  stats_->reset();
+  reset();
   return err;
 }
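
Because warmup() now calls the runner's own reset() rather than only stats_->reset(), main.cpp above no longer needs its post-warmup runner->reset() call. A hedged sketch of the behavior the callers now rely on; the actual TextLLMRunner::reset() body is not part of this diff:

// Assumed behavior, consistent with main.cpp dropping runner->reset():
// reset() clears the warmup stats and rewinds the decoding position so the
// warmup tokens do not linger in the KV cache.
void TextLLMRunner::reset() {
  stats_->reset();
  pos_ = 0; // assumed member tracking the current KV-cache position
}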

extension/llm/runner/text_prefiller.cpp

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
     cur_token = chunk_result.get();
 
-    start_pos += num_tokens_to_prefill_with;
     num_tokens_to_process += num_tokens_to_prefill_with;
   }
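
This one-line removal is the actual fix: prefill_chunk() takes start_pos by reference and already advances it by the chunk size, so the extra increment in the chunking loop moved the position forward twice per chunk. A minimal, self-contained sketch of the pattern (simplified types, not the ExecuTorch sources):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for TextPrefiller::prefill_chunk(): consumes `tokens` and advances
// `pos` by tokens.size(), as the real method does via its reference parameter.
static uint64_t prefill_chunk(const std::vector<uint64_t>& tokens, int64_t& pos) {
  pos += static_cast<int64_t>(tokens.size());
  return tokens.back(); // pretend the last prompt token is the model's output
}

// Chunking loop as it looked before this commit.
static uint64_t prefill_buggy(
    const std::vector<uint64_t>& prompt, int64_t& start_pos, size_t max_chunk) {
  uint64_t cur_token = 0;
  for (size_t i = 0; i < prompt.size(); i += max_chunk) {
    size_t len = std::min(max_chunk, prompt.size() - i);
    std::vector<uint64_t> chunk(prompt.begin() + i, prompt.begin() + i + len);
    cur_token = prefill_chunk(chunk, start_pos); // already advances start_pos
    start_pos += len; // BUG: second advance; this is the line the commit deletes
  }
  return cur_token;
}

int main() {
  std::vector<uint64_t> prompt = {1, 2, 3, 4, 5, 6, 7, 8};
  int64_t start_pos = 0;
  prefill_buggy(prompt, start_pos, /*max_chunk=*/3);
  assert(start_pos == 16); // double-counted; with the fix it would be 8
  return 0;
}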
