@@ -138,113 +138,68 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
138138TEST_F (
139139    TextPrefillerTest,
140140    PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
141-   //  Create a spy  TextPrefiller with max_seq_len = 3
141+   //  Create a real TextPrefiller with max_seq_len = 3 and parallel prefill, so the 8-token prompt below cannot be handled in a single chunk
142142  const  int64_t  max_seq_len = 3 ;
143-   auto  prefiller = createMockTextPrefiller (max_seq_len);
143+   auto  prefiller = createTextPrefiller (max_seq_len,  true ,  true );
144144
145145  //  Create prompt tokens with size > max_seq_len (8 tokens against a limit of 3)
146146  std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
147147  int64_t  start_pos = 0 ;
148148
149-   //  Set up expectations for prefill_chunk calls
150-   {
151-     InSequence seq; //  Ensure calls happen in the expected order
152- 
153-     //  First chunk: tokens [1, 2, 3]
154-     EXPECT_CALL (*prefiller, prefill_chunk (_, _))
155-         .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
156-           EXPECT_EQ (tokens.size (), 3 );
157-           EXPECT_EQ (tokens[0 ], 1 );
158-           EXPECT_EQ (tokens[1 ], 2 );
159-           EXPECT_EQ (tokens[2 ], 3 );
160-           EXPECT_EQ (pos, 0 );
161-           return  Result<uint64_t >(10 );
162-         });
163- 
164-     //  Second chunk: tokens [4, 5, 6]
165-     EXPECT_CALL (*prefiller, prefill_chunk (_, _))
166-         .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
167-           EXPECT_EQ (tokens.size (), 3 );
168-           EXPECT_EQ (tokens[0 ], 4 );
169-           EXPECT_EQ (tokens[1 ], 5 );
170-           EXPECT_EQ (tokens[2 ], 6 );
171-           EXPECT_EQ (pos, 3 );
172-           return  Result<uint64_t >(20 );
173-         });
174- 
175-     //  Third chunk: tokens [7, 8]
176-     EXPECT_CALL (*prefiller, prefill_chunk (_, _))
177-         .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
178-           EXPECT_EQ (tokens.size (), 2 );
179-           EXPECT_EQ (tokens[0 ], 7 );
180-           EXPECT_EQ (tokens[1 ], 8 );
181-           EXPECT_EQ (pos, 6 );
182-           return  Result<uint64_t >(30 );
183-         });
184-   }
149+   //  Record the tokens and position of every text_decoder_runner step call so they can be inspected after prefill returns
150+   struct  StepCall  {
151+     std::vector<uint64_t > tokens;
152+     int64_t  pos;
153+   };
154+   std::vector<StepCall> step_calls;
155+ 
156+   //  Expect step to be invoked once per chunk; every invocation is appended to step_calls
157+   EXPECT_CALL (text_decoder_runner_, step (_, _))
158+       .Times (3 ) //  8 tokens with max_seq_len = 3 -> chunks of 3, 3 and 2 tokens
159+       .WillRepeatedly (
160+           [&](executorch::extension::TensorPtr& tokens, int64_t  pos) {
161+             //  Copy the token values out of the tensor; size(1) is used as the number of tokens
162+             std::vector<uint64_t > token_values;
163+             int64_t  num_tokens = tokens->size (1 );
164+             auto * token_data = tokens->const_data_ptr <int64_t >();
165+             for  (int64_t  i = 0 ; i < num_tokens; i++) {
166+               token_values.push_back (static_cast <uint64_t >(token_data[i]));
167+             }
168+             step_calls.push_back ({token_values, pos});
169+             return  Result<executorch::aten::Tensor>(tensor); //  NOTE(review): `tensor` is declared outside this hunk, presumably on the fixture; confirm
170+           });
185171
186172  //  Call prefill
187173  auto  result = prefiller->prefill (prompt_tokens, start_pos);
188174
189175  //  Verify that prefill reported success
190176  EXPECT_EQ (result.error (), Error::Ok);
191-   EXPECT_EQ (result.get (), 30 ); //  Should return the token from the last chunk
192- 
193-   //  Verify that start_pos has been updated correctly
194-   EXPECT_EQ (start_pos, prompt_tokens.size ());
195- }
196- 
197- //  Test that prefill() handles edge cases correctly
198- TEST_F (TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
199-   //  Create a spy TextPrefiller with max_seq_len = 1
200-   const  int64_t  max_seq_len = 1 ;
201-   auto  prefiller = createMockTextPrefiller (max_seq_len);
202- 
203-   //  Create prompt tokens with size > max_seq_len
204-   std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
205-   int64_t  start_pos = 5 ; //  Non-zero starting position
206- 
207-   //  Set up expectations for prefill_chunk calls
208-   {
209-     InSequence seq;
210- 
211-     //  First chunk: token [1]
212-     EXPECT_CALL (*prefiller, prefill_chunk (_, _))
213-         .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
214-           EXPECT_EQ (tokens.size (), 1 );
215-           EXPECT_EQ (tokens[0 ], 1 );
216-           EXPECT_EQ (pos, 5 );
217-           return  Result<uint64_t >(10 );
218-         });
219- 
220-     //  Second chunk: token [2]
221-     EXPECT_CALL (*prefiller, prefill_chunk (_, _))
222-         .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
223-           EXPECT_EQ (tokens.size (), 1 );
224-           EXPECT_EQ (tokens[0 ], 2 );
225-           EXPECT_EQ (pos, 6 );
226-           return  Result<uint64_t >(20 );
227-         });
228- 
229-     //  Third chunk: token [3]
230-     EXPECT_CALL (*prefiller, prefill_chunk (_, _))
231-         .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
232-           EXPECT_EQ (tokens.size (), 1 );
233-           EXPECT_EQ (tokens[0 ], 3 );
234-           EXPECT_EQ (pos, 7 );
235-           return  Result<uint64_t >(30 );
236-         });
237-   }
238- 
239-   //  Call prefill
240-   auto  result = prefiller->prefill (prompt_tokens, start_pos);
241177
242-   //  Verify the result
243-   EXPECT_EQ (result.error (), Error::Ok);
244-   EXPECT_EQ (result.get (), 30 );
178+   //  Verify that step was called 3 times; ASSERT stops the test here so the indexing below stays in bounds
179+   ASSERT_EQ (step_calls.size (), 3 );
180+ 
181+   //  First chunk: tokens [1, 2, 3] at position 0
182+   EXPECT_EQ (step_calls[0 ].tokens .size (), 3 );
183+   EXPECT_EQ (step_calls[0 ].tokens [0 ], 1 );
184+   EXPECT_EQ (step_calls[0 ].tokens [1 ], 2 );
185+   EXPECT_EQ (step_calls[0 ].tokens [2 ], 3 );
186+   EXPECT_EQ (step_calls[0 ].pos , 0 );
187+ 
188+   //  Second chunk: tokens [4, 5, 6] at position 3
189+   EXPECT_EQ (step_calls[1 ].tokens .size (), 3 );
190+   EXPECT_EQ (step_calls[1 ].tokens [0 ], 4 );
191+   EXPECT_EQ (step_calls[1 ].tokens [1 ], 5 );
192+   EXPECT_EQ (step_calls[1 ].tokens [2 ], 6 );
193+   EXPECT_EQ (step_calls[1 ].pos , 3 );
194+ 
195+   //  Third chunk: the remaining tokens [7, 8] at position 6
196+   EXPECT_EQ (step_calls[2 ].tokens .size (), 2 );
197+   EXPECT_EQ (step_calls[2 ].tokens [0 ], 7 );
198+   EXPECT_EQ (step_calls[2 ].tokens [1 ], 8 );
199+   EXPECT_EQ (step_calls[2 ].pos , 6 );
245200
246201  //  Verify that start_pos has advanced by the full prompt length (it started at 0)
247-   EXPECT_EQ (start_pos, 8 );  //  5 (initial) + 3 (tokens) 
202+   EXPECT_EQ (start_pos, prompt_tokens. size ()); 
248203}
249204
250205//  Test that prefill() handles errors from prefill_chunk correctly
@@ -305,4 +260,119 @@ TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
305260  //  Verify that start_pos has been updated correctly
306261  EXPECT_EQ (start_pos, prompt_tokens.size ());
307262}
263+ //  Test that prefill_chunk advances start_pos by the number of prompt tokens when parallel prefill is enabled
264+ TEST_F (TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlyParallel) {
265+   //  Create a TextPrefiller with parallel prefill enabled
266+   auto  prefiller = createTextPrefiller (10 , true , true );
267+ 
268+   //  Expect a single step call on the text decoder runner and capture the position it receives
269+   int64_t  captured_pos = -1 ;
270+   EXPECT_CALL (text_decoder_runner_, step (_, _))
271+       .WillOnce ([&](executorch::extension::TensorPtr& tokens, int64_t  pos) {
272+         captured_pos = pos;
273+         //  Verify tokens shape is [1, num_tokens], i.e. all 3 prompt tokens arrive in one call
274+         EXPECT_EQ (tokens->dim (), 2 );
275+         EXPECT_EQ (tokens->size (0 ), 1 );
276+         EXPECT_EQ (tokens->size (1 ), 3 );
277+         return  Result<executorch::aten::Tensor>(tensor); //  NOTE(review): `tensor` is declared outside this hunk, presumably on the fixture; confirm
278+       });
279+ 
280+   //  Create prompt tokens
281+   std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
282+   int64_t  start_pos = 5 ; //  Non-zero starting position
283+ 
284+   //  Call prefill_chunk directly
285+   auto  result = prefiller->prefill_chunk (prompt_tokens, start_pos);
286+ 
287+   //  Verify that prefill_chunk reported success
288+   EXPECT_EQ (result.error (), Error::Ok);
289+ 
290+   //  Verify that step was called with the original start_pos (the value before any update)
291+   EXPECT_EQ (captured_pos, 5 );
292+ 
293+   //  Verify that start_pos has been updated by the number of tokens
294+   //  This is the key check: start_pos should be advanced exactly once for the whole chunk
295+   EXPECT_EQ (start_pos, 8 ); //  5 + 3 tokens
296+ }
297+ 
298+ //  Test that prefill_chunk advances start_pos by one per token when parallel prefill is disabled (sequential prefill)
299+ TEST_F (TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlySequential) {
300+   //  Create a TextPrefiller with sequential prefill (parallel disabled)
301+   auto  prefiller = createTextPrefiller (10 , true , false );
302+ 
303+   //  Record every position passed to step, in call order
304+   std::vector<int64_t > captured_positions;
305+   EXPECT_CALL (text_decoder_runner_, step (_, _))
306+       .Times (3 )
307+       .WillRepeatedly (
308+           [&](executorch::extension::TensorPtr& tokens, int64_t  pos) {
309+             captured_positions.push_back (pos);
310+             //  Verify tokens shape is [1, 1] for sequential prefill (one token per step call)
311+             EXPECT_EQ (tokens->dim (), 2 );
312+             EXPECT_EQ (tokens->size (0 ), 1 );
313+             EXPECT_EQ (tokens->size (1 ), 1 );
314+             return  Result<executorch::aten::Tensor>(tensor); //  NOTE(review): `tensor` is declared outside this hunk, presumably on the fixture; confirm
315+           });
316+ 
317+   //  Create prompt tokens
318+   std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
319+   int64_t  start_pos = 10 ; //  Non-zero starting position
320+ 
321+   //  Call prefill_chunk directly
322+   auto  result = prefiller->prefill_chunk (prompt_tokens, start_pos);
323+ 
324+   //  Verify that prefill_chunk reported success
325+   EXPECT_EQ (result.error (), Error::Ok);
326+ 
327+   //  Verify that step was called with incrementing positions; ASSERT keeps the indexing below in bounds
328+   ASSERT_EQ (captured_positions.size (), 3 );
329+   EXPECT_EQ (captured_positions[0 ], 10 ); //  First token at initial start_pos
330+   EXPECT_EQ (captured_positions[1 ], 11 ); //  Second token at start_pos + 1
331+   EXPECT_EQ (captured_positions[2 ], 12 ); //  Third token at start_pos + 2
332+ 
333+   //  Verify that start_pos has been updated by the number of tokens
334+   //  This is the key check: start_pos should be advanced exactly once per token
335+   EXPECT_EQ (start_pos, 13 ); //  10 + 3 tokens
336+ }
337+ 
338+ //  Test that prefill with chunking updates start_pos correctly across chunks, starting from a non-zero position.
339+ //  This test would have caught the bug where start_pos was being updated twice.
340+ TEST_F (
341+     TextPrefillerTest,
342+     PrefillWithChunkingUpdatesStartPosCorrectlyAcrossChunks) {
343+   //  Create a TextPrefiller with max_seq_len = 3 and parallel prefill
344+   auto  prefiller = createTextPrefiller (3 , true , true );
345+ 
346+   //  Record every position passed to step, in call order
347+   std::vector<int64_t > captured_positions;
348+   EXPECT_CALL (text_decoder_runner_, step (_, _))
349+       .Times (3 ) //  Should be called 3 times: [1,2,3], [4,5,6], [7,8]
350+       .WillRepeatedly (
351+           [&](executorch::extension::TensorPtr& tokens, int64_t  pos) {
352+             captured_positions.push_back (pos);
353+             return  Result<executorch::aten::Tensor>(tensor); //  NOTE(review): `tensor` is declared outside this hunk, presumably on the fixture; confirm
354+           });
355+ 
356+   //  Create prompt tokens that exceed max_seq_len (8 tokens against a limit of 3)
357+   std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
358+   int64_t  start_pos = 100 ; //  Non-zero starting position
359+ 
360+   //  Call prefill (which will chunk internally)
361+   auto  result = prefiller->prefill (prompt_tokens, start_pos);
362+ 
363+   //  Verify that prefill reported success
364+   EXPECT_EQ (result.error (), Error::Ok);
365+ 
366+   //  Verify that step was called with correct positions for each chunk
367+   //  If start_pos were updated twice (the bug), the second and third positions would be wrong
368+   ASSERT_EQ (captured_positions.size (), 3 );
369+   EXPECT_EQ (captured_positions[0 ], 100 ); //  Chunk 1: tokens [1,2,3]
370+   EXPECT_EQ (captured_positions[1 ], 103 ); //  Chunk 2: tokens [4,5,6]
371+   EXPECT_EQ (captured_positions[2 ], 106 ); //  Chunk 3: tokens [7,8]
372+ 
373+   //  Verify that final start_pos is correct
374+   //  This is the key check for the bug: start_pos should be exactly
375+   //  initial_pos + num_tokens, not double-incremented
376+   EXPECT_EQ (start_pos, 108 ); //  100 + 8 tokens
377+ }
308378} //  namespace
0 commit comments