Commit 963dc0b

[Model Runner V2] Minor optimization for eagle input processing (vllm-project#32535)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 8cc26ac · commit 963dc0b

File tree: 2 files changed (+12, −14 lines)
vllm/v1/worker/gpu/model_runner.py

Lines changed: 2 additions & 8 deletions
@@ -827,20 +827,14 @@ def propose_draft(
         num_rejected: torch.Tensor,
     ) -> torch.Tensor:
         assert self.speculator is not None
-        last_sampled_tokens = self.req_states.last_sampled_tokens[
-            input_batch.idx_mapping
-        ]
-        next_prefill_tokens = self.req_states.next_prefill_tokens[
-            input_batch.idx_mapping
-        ]
         draft_tokens = self.speculator.propose(
             input_batch,
             last_hidden_states,
             aux_hidden_states,
             num_sampled,
             num_rejected,
-            last_sampled_tokens,
-            next_prefill_tokens,
+            self.req_states.last_sampled_tokens,
+            self.req_states.next_prefill_tokens,
             self.sampler.sampling_states.temperature.gpu,
             self.sampler.sampling_states.seeds.gpu,
         )
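
The change above drops the host-side gathers (`buf[input_batch.idx_mapping]`), each of which launches a gather kernel and materializes a fresh [num_reqs] tensor on every step, and instead passes the full [max_num_reqs] request-state buffers down to the speculator. A minimal sketch of the before/after pattern, with hypothetical sizes and names not taken from the vLLM code:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical sizes, for illustration only.
max_num_reqs, num_reqs = 1024, 8
# Persistent per-request-state buffer, indexed by request-state slot.
last_sampled_tokens = torch.zeros(max_num_reqs, dtype=torch.int64, device=device)
# Maps batch position -> request-state slot for the current step.
idx_mapping = torch.randint(0, max_num_reqs, (num_reqs,), device=device)

# Before: advanced indexing launches a gather kernel and materializes
# a fresh [num_reqs] tensor for each buffer on every step.
gathered = last_sampled_tokens[idx_mapping]

# After: hand the whole [max_num_reqs] buffer plus idx_mapping to the
# downstream Triton kernel and let it resolve the indirection itself,
# so no intermediate tensor is created here.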

vllm/v1/worker/gpu/spec_decode/eagle.py

Lines changed: 10 additions & 6 deletions
@@ -195,9 +195,9 @@ def propose(
         num_sampled: torch.Tensor,
         # [num_reqs]
         num_rejected: torch.Tensor,
-        # [num_reqs]
+        # [max_num_reqs]
         last_sampled: torch.Tensor,
-        # [num_reqs]
+        # [max_num_reqs]
         next_prefill_tokens: torch.Tensor,
         # [max_num_reqs]
         temperature: torch.Tensor,
@@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel(
     eagle_positions_ptr,
     target_input_ids_ptr,
     target_positions_ptr,
+    idx_mapping_ptr,
     last_sampled_ptr,
     next_prefill_tokens_ptr,
     num_sampled_ptr,
@@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel(
     BLOCK_SIZE: tl.constexpr,
 ):
     batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+
     query_start = tl.load(query_start_loc_ptr + batch_idx)
     query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
     query_len = query_end - query_start
@@ -338,11 +341,11 @@

     num_sampled = tl.load(num_sampled_ptr + batch_idx)
     if num_sampled > 0:
-        next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
+        next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
     else:
         # Chunked prefilling.
         # Get the next prefill token.
-        next_token = tl.load(next_prefill_tokens_ptr + batch_idx)
+        next_token = tl.load(next_prefill_tokens_ptr + req_state_idx)

     # Shift target_input_ids by one.
     for i in range(1, query_len, BLOCK_SIZE):
@@ -370,9 +373,9 @@ def prepare_eagle_inputs(
     num_sampled: torch.Tensor,
     # [num_reqs]
     num_rejected: torch.Tensor,
-    # [num_reqs]
+    # [max_num_reqs]
     last_sampled: torch.Tensor,
-    # [num_reqs]
+    # [max_num_reqs]
     next_prefill_tokens: torch.Tensor,
 ) -> torch.Tensor:
     num_reqs = input_batch.num_reqs
@@ -387,6 +390,7 @@ def prepare_eagle_inputs(
         input_buffers.positions,
         input_batch.input_ids,
         input_batch.positions,
+        input_batch.idx_mapping,
         last_sampled,
         next_prefill_tokens,
         num_sampled,
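
The kernel-side indirection in isolation: each Triton program looks up its request-state slot through `idx_mapping` and reads directly from the persistent [max_num_reqs] buffer. A self-contained sketch of that pattern; the kernel, wrapper, and names below are hypothetical, not vLLM's API:

import torch
import triton
import triton.language as tl

@triton.jit
def _gather_via_mapping_kernel(out_ptr, buf_ptr, idx_mapping_ptr):
    # One program per batch slot, mirroring the diff's pattern.
    batch_idx = tl.program_id(0)
    # Resolve the batch slot to its persistent request-state index.
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    # Read straight from the [max_num_reqs] buffer; no pre-gathered
    # [num_reqs] copy is needed on the host side.
    val = tl.load(buf_ptr + req_state_idx)
    tl.store(out_ptr + batch_idx, val)

def gather_via_mapping(buf: torch.Tensor, idx_mapping: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(idx_mapping, dtype=buf.dtype)
    _gather_via_mapping_kernel[(idx_mapping.numel(),)](out, buf, idx_mapping)
    return out

Since `_prepare_eagle_inputs_kernel` already launches one program per request, folding the lookup in costs one extra `tl.load` per program, whereas the old path paid two full tensor gathers (one per buffer) before every speculator call. This also explains the comment changes from [num_reqs] to [max_num_reqs]: the arguments are now the persistent buffers themselves rather than per-batch gathered copies.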
