Skip to content

Commit 42a70e0

Browse files
committed
Commit before rebase.
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent 94b2930 commit 42a70e0

File tree

3 files changed

+90
-29
lines changed

3 files changed

+90
-29
lines changed

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def __init__(self, config: CUDAGraphRunnerConfig):
9999
Callable[[], Optional[torch.Tensor]]] = {}
100100
self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
101101
self.memory_pool = config.cuda_graph_mem_pool
102-
self.padding_dummy_request: Optional["Request"] = None
102+
self.padding_dummies: Dict[int, "Request"] = {
103+
} # draft_len -> dummy_request
103104

104105
self.shared_static_tensors: Dict[str, torch.Tensor] = {}
105106
if self.enabled:
@@ -166,6 +167,7 @@ def maybe_get_cuda_graph(
166167
batch: ScheduledRequests,
167168
iter_counter: int,
168169
enable_spec_decode: bool,
170+
runtime_draft_len: int,
169171
attn_metadata: Any,
170172
spec_metadata: Optional[Any] = None,
171173
draft_tokens_cuda: Optional[torch.Tensor] = None,
@@ -372,26 +374,31 @@ def _get_padded_batch(self, batch: ScheduledRequests,
372374
# No padding if it would create too many concurrent requests.
373375
# This is not strictly required, but we should probably
374376
# respect the requirement just in case that changes in the future.
375-
if self.padding_dummy_request is None:
377+
if runtime_draft_len not in self.padding_dummies:
376378
available_blocks = kv_cache_manager.get_num_free_blocks()
377379
# No padding if not enough KV cache space
378380
if available_blocks < 1:
379381
return 0
380-
381-
self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
382-
[CUDA_GRAPH_DUMMY_REQUEST_ID],
382+
# Create dummy for this specific draft_len (happens once per unique draft_len)
383+
# Use unique request ID per draft_len to avoid conflicts
384+
dummy_req_id = CUDA_GRAPH_DUMMY_REQUEST_ID - runtime_draft_len
385+
dummy = kv_cache_manager.add_dummy_requests(
386+
[dummy_req_id],
383387
is_gen=True,
384388
max_num_draft_tokens=runtime_draft_len,
385389
use_mrope=self.config.use_mrope,
386390
max_beam_width=self.config.max_beam_width)[0]
387-
self.padding_dummy_request.is_cuda_graph_dummy = True
391+
dummy.is_cuda_graph_dummy = True
388392
spec_res_mgr = resource_manager.get_resource_manager(
389393
ResourceManagerType.SPEC_RESOURCE_MANAGER)
390394
if spec_res_mgr:
391-
spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
395+
spec_res_mgr.add_dummy_requests([dummy_req_id])
396+
397+
# Store for reuse
398+
self.padding_dummies[runtime_draft_len] = dummy
392399

393-
batch.generation_requests.extend([self.padding_dummy_request] *
394-
padding_size)
400+
padding_dummy = self.padding_dummies[runtime_draft_len]
401+
batch.generation_requests.extend([padding_dummy] * padding_size)
395402
return padding_size
396403

397404
def _round_up_batch_size(self, batch_size: int) -> int:
@@ -426,7 +433,7 @@ def clear(self):
426433
self.graphs.clear()
427434
self.graph_outputs.clear()
428435
self.graph_metadata.clear()
429-
self.padding_dummy_request = None
436+
self.padding_dummies.clear()
430437
del self.memory_pool
431438
self.memory_pool = None
432439
torch.cuda.empty_cache()

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,42 @@ def _run_cuda_graph_warmup(self, resource_manager: ResourceManager):
658658
self._capture_generation_cuda_graphs(resource_manager)
659659
self._capture_piecewise_cuda_graphs(resource_manager)
660660

661+
def _graphs_for_dynamic_draft_length(self):
662+
"""
663+
Compute the set of (batch_size, draft_len) pairs that are actually reachable.
664+
Used in dynamic draft length feature.
665+
"""
666+
graphs_to_capture = []
667+
schedule_thresholds = sorted(self.spec_config.draft_len_schedule.keys())
668+
669+
# Only iterate over actual CUDA graph batch sizes, not all possible batch sizes
670+
for graph_bs in self._cuda_graph_batch_sizes:
671+
idx = bisect.bisect_right(schedule_thresholds, graph_bs)
672+
if idx == 0:
673+
draft_len = 0 # Defensive
674+
else:
675+
draft_len = self.spec_config.draft_len_schedule[
676+
schedule_thresholds[idx - 1]]
677+
678+
graphs_to_capture.append((graph_bs, draft_len))
679+
680+
return list(
681+
set(graphs_to_capture)) # Use set to remove duplicates if any
682+
683+
# def _round_up_to_graph_size(self, actual_bs: int) -> int:
684+
# """Round up actual batch size to nearest CUDA graph batch size using binary search."""
685+
# if not self._cuda_graph_batch_sizes:
686+
# return 0
687+
688+
# idx = bisect.bisect_left(self._cuda_graph_batch_sizes, actual_bs)
689+
690+
# # If exact match or idx points to next larger size
691+
# if idx < len(self._cuda_graph_batch_sizes):
692+
# return self._cuda_graph_batch_sizes[idx]
693+
694+
# # actual_bs is larger than all available sizes
695+
# return self._cuda_graph_batch_sizes[-1]
696+
661697
def _capture_generation_cuda_graphs(self,
662698
resource_manager: ResourceManager):
663699
"""Captures CUDA graphs for pure generation steps."""
@@ -674,38 +710,48 @@ def _capture_generation_cuda_graphs(self,
674710
cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
675711
reverse=True)
676712
# Create CUDA graphs for different draft lengths
677-
draft_lengths = []
713+
# draft_lengths = []
678714
if self.is_draft_model:
679715
if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
680716
spec_resource_manager, Eagle3ResourceManager):
681717
# The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
682-
draft_lengths.append(self.original_max_total_draft_tokens)
718+
draft_len = self.original_max_total_draft_tokens
683719
else:
684-
draft_lengths.append(self.max_total_draft_tokens)
720+
draft_len = self.max_total_draft_tokens
721+
graphs_to_capture = [(bs, draft_len)
722+
for bs in cuda_graph_batch_sizes]
723+
elif (self.spec_config
724+
and hasattr(self.spec_config, 'draft_len_schedule')
725+
and self.spec_config.draft_len_schedule is not None):
726+
# target model with draft_len_schedule: compute exact reachable set
727+
graphs_to_capture = self._graphs_for_dynamic_draft_length()
685728
else:
686729
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
687730
# so that when we disable spec decode at runtime, we can still run the captured graph.
688731
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
732+
graphs_to_capture = []
689733
if (self.max_total_draft_tokens > 0
690734
and not self.spec_config.spec_dec_mode.use_one_engine()
691735
# Assume that speculation is always on if the user didn't give us a max_concurrency
692736
# value. This will save on memory.
693737
and self.spec_config.max_concurrency is not None):
694-
draft_lengths.append(0)
695-
draft_lengths = [self.max_total_draft_tokens]
738+
graphs_to_capture.extend([(bs, 0)
739+
for bs in cuda_graph_batch_sizes])
740+
else:
741+
graphs_to_capture.extend([(bs, self.max_total_draft_tokens)
742+
for bs in cuda_graph_batch_sizes])
696743

697-
for bs in cuda_graph_batch_sizes:
744+
graphs_to_capture = sorted(graphs_to_capture, reverse=True)
745+
for bs, draft_len in graphs_to_capture:
698746
if bs > self.batch_size:
699747
continue
700-
701-
for draft_len in draft_lengths:
702-
warmup_request = self._create_cuda_graph_warmup_request(
703-
resource_manager, bs, draft_len)
704-
with self._release_batch_context(warmup_request,
705-
resource_manager) as batch:
706-
if batch is None:
707-
# No KV cache space, cannot continue capturing graphs
708-
return
748+
warmup_request = self._create_cuda_graph_warmup_request(
749+
resource_manager, bs, draft_len)
750+
with self._release_batch_context(warmup_request,
751+
resource_manager) as batch:
752+
if batch is None:
753+
# No KV cache space, cannot continue capturing graphs
754+
return
709755

710756
logger.info(
711757
f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,24 @@ def should_use_spec_decode(self, requests: List[LlmRequest],
6868
def pad_draft_tokens_for_cuda_graph(
6969
self, scheduled_requests: ScheduledRequests) -> None:
7070
"""
71-
Pad draft tokens to the static max total draft tokens for CUDA graph compatibility.
71+
Pad draft tokens for CUDA graph compatibility.
72+
CUDA graphs require all requests in a batch to have the same tensor shape.
73+
Individual requests may generate fewer draft tokens (e.g., NGram mismatches,
74+
early stopping), but all must be padded to the same length.
7275
7376
Args:
7477
scheduled_requests: The scheduled requests to pad
7578
"""
7679
for req in scheduled_requests.generation_requests:
7780
num_draft_tokens = get_draft_token_length(req)
78-
req.py_draft_tokens.extend(
79-
0 for _ in range(self._static_max_total_draft_tokens -
80-
num_draft_tokens))
81+
if self.draft_len_schedule is not None:
82+
# Pad to current iteration's (dynamic) max_draft_tokens if dynamic draft length is enabled
83+
target_len = self.max_total_draft_tokens
84+
else:
85+
target_len = self._static_max_total_draft_tokens
86+
if num_draft_tokens < target_len:
87+
req.py_draft_tokens.extend(
88+
0 for _ in range(target_len - num_draft_tokens))
8189

8290
def get_draft_len_for_batch_size(self, batch_size: int) -> int:
8391
"""

0 commit comments

Comments
 (0)