|
1 | 1 | from abc import ABC, abstractmethod |
| 2 | +from bisect import bisect_left |
2 | 3 | from typing import Dict, List, Optional, final |
3 | 4 |
|
4 | 5 | from tensorrt_llm.logger import logger |
@@ -67,16 +68,20 @@ def should_use_spec_decode(self, requests: List[LlmRequest], |
def pad_draft_tokens_for_cuda_graph(
        self, scheduled_requests: ScheduledRequests) -> None:
    """
    Pad draft tokens to max total draft tokens for CUDA graph compatibility.

    When draft_len_schedule is used, pads to the current dynamic
    max_total_draft_tokens. Otherwise, pads to the static max for
    consistent CUDA graph tensor sizes.

    Args:
        scheduled_requests: The scheduled requests to pad
    """
    # The dynamic maximum applies while a draft-length schedule is
    # active; the static maximum keeps CUDA graph tensor sizes
    # consistent otherwise.
    if self.draft_len_schedule is not None:
        pad_to = self.max_total_draft_tokens
    else:
        pad_to = self._static_max_total_draft_tokens
    for request in scheduled_requests.generation_requests:
        existing = get_draft_token_length(request)
        request.py_draft_tokens.extend([0] * (pad_to - existing))
|
81 | 86 | def get_draft_len_for_batch_size(self, runtime_batch_size: int) -> int: |
82 | 87 | """ |
|
0 commit comments