Skip to content

Commit 361bf82

Browse files
committed
Debugged the pytest deadlock; the stage 1 pytest run now passes.
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent 7b0bf55 commit 361bf82

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from bisect import bisect_left
23
from typing import Dict, List, Optional, final
34

45
from tensorrt_llm.logger import logger
@@ -67,16 +68,20 @@ def should_use_spec_decode(self, requests: List[LlmRequest],
6768
def pad_draft_tokens_for_cuda_graph(
6869
self, scheduled_requests: ScheduledRequests) -> None:
6970
"""
70-
Pad draft tokens to the static max total draft tokens for CUDA graph compatibility.
71+
Pad draft tokens to max total draft tokens for CUDA graph compatibility.
72+
73+
When draft_len_schedule is used, pads to the current dynamic max_total_draft_tokens.
74+
Otherwise, pads to the static max for consistent CUDA graph tensor sizes.
7175
7276
Args:
7377
scheduled_requests: The scheduled requests to pad
7478
"""
79+
# Use dynamic max when draft_len_schedule is active, otherwise use static max
80+
target_draft_len = self.max_total_draft_tokens if self.draft_len_schedule is not None else self._static_max_total_draft_tokens
7581
for req in scheduled_requests.generation_requests:
7682
num_draft_tokens = get_draft_token_length(req)
7783
req.py_draft_tokens.extend(
78-
0 for _ in range(self._static_max_total_draft_tokens -
79-
num_draft_tokens))
84+
0 for _ in range(target_draft_len - num_draft_tokens))
8085

8186
def get_draft_len_for_batch_size(self, runtime_batch_size: int) -> int:
8287
"""

tests/unittest/_torch/speculative/test_draft_len_schedule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,8 @@ def instrumented_prepare_draft(scheduled_batch, resource_manager):
269269
# This matches what the actual code does: len(executor.active_requests)
270270
expected_mapping = {
271271
1: 5,
272-
2: 5,
273-
3: 5,
272+
2: 4,
273+
3: 4,
274274
4: 4,
275275
5: 3,
276276
6: 2,

0 commit comments

Comments
 (0)