
Commit d8b5aeb

[https://nvbugs/5652062][fix] Rewind kv_cache and reset draft tokens (#10160)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 46e4af5 commit d8b5aeb

File tree

3 files changed: +45 −48 lines


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 37 additions & 40 deletions
@@ -1567,6 +1567,24 @@ def _executor_loop_overlap(self):
                     # For generation requests which have completed KV cache transfer
                     self._prepare_disagg_gen_transmission_complete(
                         scheduled_batch)
+
+                has_draft_batch = self.drafter is not None and self.previous_batch is not None and self.use_spec_decode and self.drafter.should_forward_draft_model(
+                    scheduled_batch)
+                # Reset the draft tokens to avoid preparing resources for the draft model.
+                if self.drafter is not None and self.use_spec_decode and not has_draft_batch:
+                    self.use_spec_decode = False
+                    # We are not running the draft model. Remove the draft tokens and turn off spec
+                    # decode so that the requests get handled correctly.
+                    # One corner case: when we have at least one context request, we have to keep spec
+                    # dec on. This ensures that we capture hidden states for requests that haven't done
+                    # prefill yet.
+                    self.use_spec_decode = False
+                    self.model_engine.enable_spec_decode = len(
+                        scheduled_batch.context_requests) > 0
+                    if not self.model_engine.enable_spec_decode:
+                        for request in scheduled_batch.all_requests():
+                            request.py_draft_tokens = []
+
                 self.resource_manager.prepare_resources(scheduled_batch)
 
                 self._kv_connector_start_batch(scheduled_batch)

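
The hunk above moves the draft-batch decision to the top of the iteration, before prepare_resources, so the executor never prepares resources for draft tokens that will not be used this step. The standalone sketch below mirrors that gating logic with toy types so it can be read and run in isolation; ToyRequest, ToyBatch, and decide_draft_batch are illustrative names, not TensorRT-LLM classes.

# Illustrative sketch of the draft-batch gating now done in the executor loop.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyRequest:
    py_draft_tokens: List[int] = field(default_factory=list)


@dataclass
class ToyBatch:
    context_requests: List[ToyRequest]
    generation_requests: List[ToyRequest]

    def all_requests(self) -> List[ToyRequest]:
        return self.context_requests + self.generation_requests


def decide_draft_batch(batch: ToyBatch, have_drafter: bool, use_spec_decode: bool,
                       have_previous_batch: bool, should_forward_draft_model: bool):
    """Return (has_draft_batch, use_spec_decode, enable_spec_decode)."""
    has_draft_batch = (have_drafter and have_previous_batch and use_spec_decode
                       and should_forward_draft_model)
    enable_spec_decode = use_spec_decode
    if have_drafter and use_spec_decode and not has_draft_batch:
        # Not running the draft model this iteration: turn spec decode off so
        # no resources are prepared for draft tokens ...
        use_spec_decode = False
        # ... but keep the engine's spec-decode path on if any request still has
        # prefill to do, so its hidden states are captured for the drafter.
        enable_spec_decode = len(batch.context_requests) > 0
        if not enable_spec_decode:
            for request in batch.all_requests():
                request.py_draft_tokens = []
    return has_draft_batch, use_spec_decode, enable_spec_decode


if __name__ == "__main__":
    batch = ToyBatch(context_requests=[],
                     generation_requests=[ToyRequest(py_draft_tokens=[7, 7])])
    print(decide_draft_batch(batch, have_drafter=True, use_spec_decode=True,
                             have_previous_batch=True, should_forward_draft_model=False))
    # -> (False, False, False), and the request's stale draft tokens were cleared.
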
@@ -1602,8 +1620,11 @@ def _executor_loop_overlap(self):
                 # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                 use_previous_draft_tokens = self.has_previous_draft_tokens
                 num_accepted_tokens_device = None
-                if self.drafter is not None and (self.use_spec_decode or
-                                                 use_previous_draft_tokens):
+
+                target_inputs = None
+                num_accepted_tokens_device = None
+
+                if has_draft_batch:
                     target_inputs, num_accepted_tokens_device = self._handle_speculative_decoding(
                         scheduled_batch, previous_tensors,
                         previous_tensors_device)
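
The num_accepted_tokens_device returned by _handle_speculative_decoding counts, per request, how many of the previously proposed draft tokens the target model agreed with. As a reference point, the sketch below shows the standard greedy acceptance rule that such a count is based on; it is a plain-Python illustration of the concept, not the actual _accept_draft_tokens implementation.

from typing import List, Tuple


def greedy_accept(draft_tokens: List[int], target_tokens: List[int]) -> Tuple[int, List[int]]:
    """Count how many proposed draft tokens match the target model's own samples.

    target_tokens holds the target model's sampled token at every draft position
    plus one extra bonus position, so len(target_tokens) == len(draft_tokens) + 1.
    The matching prefix is accepted; the first disagreement (or the bonus token,
    if everything matched) becomes the next committed token.
    """
    assert len(target_tokens) == len(draft_tokens) + 1
    num_accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break
        num_accepted += 1
    committed = target_tokens[:num_accepted + 1]  # accepted prefix + correction/bonus
    return num_accepted, committed


print(greedy_accept([5, 9, 2, 4], [5, 9, 7, 1, 3]))  # -> (2, [5, 9, 7])

Roughly speaking, the draft positions that are not accepted are what py_rewind_len accounts for when the KV cache is rewound in the resource_manager.py hunk below.
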
@@ -2746,44 +2767,20 @@ def _handle_speculative_decoding(
     ) -> Tuple[Optional[SampleStateTensorsMTP], Optional[torch.Tensor]]:
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
-            # Do an early checking to see if we need to forward the draft model.
-            # If needed, the overlap should happen between the target requests and the draft requests.
-            # Otherwise, we can still do overlap between the previous target requests and the current target requests.
-            has_draft_batch = (
-                self.previous_batch is not None and self.use_spec_decode
-                and self.drafter.should_forward_draft_model(scheduled_batch))
-
-            new_target_inputs = None
-            num_accepted_tokens_device = None
-            if has_draft_batch:
-                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
-                assert target_outputs is not None, "target_outputs should not be None"
-                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
-                    scheduled_batch=scheduled_batch,
-                    target_inputs=target_inputs,
-                    target_outputs=target_outputs)
-
-            if has_draft_batch:
-                self.drafter.generate_draft_tokens_with_overlap(
-                    scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None,
-                    new_target_inputs, num_accepted_tokens_device)
-
-                # Pad draft tokens to the max draft length for CUDA graph compatibility
-                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
-            else:
-                self.has_previous_draft_tokens = False
-                # We are not running the draft model. Remove the draft tokens and turn off spec
-                # decode so that the requests get handled correctly.
-                # One corner case: when we have at least one context request, we have to keep spec
-                # dec on. This ensures that we capture hidden states for requests that haven't done
-                # prefill yet.
-                self.use_spec_decode = False
-                self.model_engine.enable_spec_decode = len(
-                    scheduled_batch.context_requests) > 0
-                if not self.model_engine.enable_spec_decode:
-                    for request in scheduled_batch.all_requests():
-                        request.py_draft_tokens = []
+            target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+            assert target_outputs is not None, "target_outputs should not be None"
+            new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                scheduled_batch=scheduled_batch,
+                target_inputs=target_inputs,
+                target_outputs=target_outputs)
+
+            self.drafter.generate_draft_tokens_with_overlap(
+                scheduled_batch, self.resource_manager,
+                previous_tensors.device if previous_tensors else None,
+                new_target_inputs, num_accepted_tokens_device)
+
+            # Pad draft tokens to the max draft length for CUDA graph compatibility
+            self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
 
         return new_target_inputs, num_accepted_tokens_device
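
The retained comment about padding reflects a CUDA graph constraint: a graph is captured for fixed tensor shapes, so a shorter draft proposal has to be padded up to the maximum draft length rather than resized. A minimal illustration follows; pad_draft_tokens and pad_id are made-up names for the example.

from typing import List


def pad_draft_tokens(draft_tokens: List[int], max_draft_len: int, pad_id: int = 0) -> List[int]:
    """Pad a proposal to a fixed length so a graph captured for max_draft_len can replay."""
    assert len(draft_tokens) <= max_draft_len
    return draft_tokens + [pad_id] * (max_draft_len - len(draft_tokens))


print(pad_draft_tokens([11, 42], max_draft_len=4))  # -> [11, 42, 0, 0]
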

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 5 additions & 5 deletions
@@ -599,11 +599,11 @@ def update_resources(self,
             self.update_kv_cache_draft_token_location(scheduled_batch,
                                                       attn_metadata,
                                                       kv_cache_dtype_byte_size)
-            # rewind kv cache
-            for request in scheduled_batch.generation_requests:
-                if request.state != LlmRequestState.GENERATION_COMPLETE:
-                    if request.py_rewind_len > 0:
-                        self.rewind_kv_cache(request, request.py_rewind_len)
+        # rewind kv cache
+        for request in scheduled_batch.generation_requests:
+            if request.state != LlmRequestState.GENERATION_COMPLETE:
+                if request.py_rewind_len > 0:
+                    self.rewind_kv_cache(request, request.py_rewind_len)
 
         # For context requests, we store the blocks for reuse.
         for request in scheduled_batch.context_requests:
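
The rewind loop above returns the KV positions that were written for draft tokens the target model did not accept: each generation request that is not yet complete and has a nonzero py_rewind_len gets that many trailing cache positions released. The toy cache below illustrates the bookkeeping; ToyKVCache and its methods are illustrative only and unrelated to the real KVCacheManager.

class ToyKVCache:
    """Minimal stand-in for one request's KV cache, tracked as a token count."""

    def __init__(self) -> None:
        self.num_cached_tokens = 0

    def append(self, num_tokens: int) -> None:
        # Target and draft tokens are written speculatively during the forward pass.
        self.num_cached_tokens += num_tokens

    def rewind(self, rewind_len: int) -> None:
        # Drop the trailing positions that belonged to rejected draft tokens so the
        # cache length matches the tokens actually committed to the sequence.
        assert 0 <= rewind_len <= self.num_cached_tokens
        self.num_cached_tokens -= rewind_len


cache = ToyKVCache()
cache.append(5)                # positions written during one speculative step
py_rewind_len = 2              # e.g. 2 of the proposed draft tokens were rejected
cache.rewind(py_rewind_len)
print(cache.num_cached_tokens) # -> 3
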

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 3 additions & 3 deletions
@@ -588,7 +588,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
     max_batch_size = 4
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
-                                    max_tokens=8192)
+                                    max_tokens=4096)
     cuda_graph_config = CudaGraphConfig(batch_sizes=[1, 2, 4],
                                         enable_padding=True)
 
@@ -599,7 +599,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
         cuda_graph_config=cuda_graph_config,
         max_batch_size=max_batch_size,
         kv_cache_config=kv_cache_config,
-        max_seq_len=8192,
+        max_seq_len=2048,
         enable_chunked_prefill=enable_chunked_prefill,
     )
 
@@ -617,7 +617,7 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
         "The future of AI is"
     ]
 
-    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+    sampling_params = SamplingParams(max_tokens=2048, temperature=0)
     llm_spec.generate(prompts, sampling_params)
     llm_spec.shutdown()
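
The test now combines a smaller KV-cache budget with a much longer generation (2048 output tokens instead of 20), presumably so the KV-cache rewind path is exercised over many decode steps rather than a handful. The snippet below just collects the changed values; the import paths are an assumption and the Eagle-3 speculative config that the real test also passes is omitted, so treat it as a sketch of the setup rather than the full test.

# Sketch of the updated test configuration (import paths assumed; the Eagle-3
# speculative config used by the real test is omitted here).
from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

max_batch_size = 4
max_draft_len = 4  # feeds the (omitted) Eagle-3 speculative config

# KV-cache budget halved from 8192 to 4096 tokens; enable_block_reuse is a test
# parameter, fixed to False here only to keep the sketch self-contained.
kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=4096)
cuda_graph_config = CudaGraphConfig(batch_sizes=[1, 2, 4], enable_padding=True)

# max_seq_len drops from 8192 to 2048 and generation runs up to 2048 tokens
# (previously 20), so rewinds of rejected draft tokens accumulate over the run.
sampling_params = SamplingParams(max_tokens=2048, temperature=0)
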
