@@ -195,9 +195,9 @@ def propose(
195195 num_sampled : torch .Tensor ,
196196 # [num_reqs]
197197 num_rejected : torch .Tensor ,
198- # [num_reqs ]
198+ # [max_num_reqs ]
199199 last_sampled : torch .Tensor ,
200- # [num_reqs ]
200+ # [max_num_reqs ]
201201 next_prefill_tokens : torch .Tensor ,
202202 # [max_num_reqs]
203203 temperature : torch .Tensor ,
@@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel(
320320 eagle_positions_ptr ,
321321 target_input_ids_ptr ,
322322 target_positions_ptr ,
323+ idx_mapping_ptr ,
323324 last_sampled_ptr ,
324325 next_prefill_tokens_ptr ,
325326 num_sampled_ptr ,
@@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel(
328329 BLOCK_SIZE : tl .constexpr ,
329330):
330331 batch_idx = tl .program_id (0 )
332+ req_state_idx = tl .load (idx_mapping_ptr + batch_idx )
333+
331334 query_start = tl .load (query_start_loc_ptr + batch_idx )
332335 query_end = tl .load (query_start_loc_ptr + batch_idx + 1 )
333336 query_len = query_end - query_start
@@ -338,11 +341,11 @@ def _prepare_eagle_inputs_kernel(
338341
339342 num_sampled = tl .load (num_sampled_ptr + batch_idx )
340343 if num_sampled > 0 :
341- next_token = tl .load (last_sampled_ptr + batch_idx ).to (tl .int32 )
344+ next_token = tl .load (last_sampled_ptr + req_state_idx ).to (tl .int32 )
342345 else :
343346 # Chunked prefilling.
344347 # Get the next prefill token.
345- next_token = tl .load (next_prefill_tokens_ptr + batch_idx )
348+ next_token = tl .load (next_prefill_tokens_ptr + req_state_idx )
346349
347350 # Shift target_input_ids by one.
348351 for i in range (1 , query_len , BLOCK_SIZE ):
@@ -370,9 +373,9 @@ def prepare_eagle_inputs(
370373 num_sampled : torch .Tensor ,
371374 # [num_reqs]
372375 num_rejected : torch .Tensor ,
373- # [num_reqs ]
376+ # [max_num_reqs ]
374377 last_sampled : torch .Tensor ,
375- # [num_reqs ]
378+ # [max_num_reqs ]
376379 next_prefill_tokens : torch .Tensor ,
377380) -> torch .Tensor :
378381 num_reqs = input_batch .num_reqs
@@ -387,6 +390,7 @@ def prepare_eagle_inputs(
387390 input_buffers .positions ,
388391 input_batch .input_ids ,
389392 input_batch .positions ,
393+ input_batch .idx_mapping ,
390394 last_sampled ,
391395 next_prefill_tokens ,
392396 num_sampled ,
0 commit comments