Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,13 +1039,27 @@ def _prepare_and_schedule_batch(self):

if self.drafter is not None:
# Honor permanent disable flag based on rolling acceptance first
if getattr(self, 'speculation_permanently_disabled', False):
if self.drafter.draft_len_schedule is not None:
batch_size_input = len(self.active_requests)

self.max_total_draft_tokens = self.drafter.get_draft_len_for_batch_size(
batch_size_input)

self.drafter.update_max_total_draft_tokens(
self.max_total_draft_tokens)

# Check if draft_len=0 → immediately disable
# self.max_total_draft_tokens==0 is only possible when draft_len_schedule is provided
# for example, draft_len_schedule = {1:4, 4:2, 8:0}, batch_size >= 8 will set self.max_draft_len = 0
if self.drafter.draft_len_schedule is not None and self.max_total_draft_tokens == 0:
self.use_spec_decode = False
elif getattr(self, 'speculation_permanently_disabled', False):
self.use_spec_decode = False
else:
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests, self.max_batch_size,
self.model_engine.max_num_tokens,
self.model_engine.spec_config.max_total_draft_tokens)
self.max_total_draft_tokens)
logger.debug(f"Use spec decode: {self.use_spec_decode}")
self.model_engine.enable_spec_decode = self.use_spec_decode

Expand All @@ -1055,10 +1069,9 @@ def _prepare_and_schedule_batch(self):
LlmRequestState.GENERATION_IN_PROGRESS,
LlmRequestState.DISAGG_GENERATION_INIT):
continue
max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
request.draft_tokens = [
0
] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
] * self.max_total_draft_tokens if self.max_total_draft_tokens > 0 else []

# If speculation is off, this function sets py_draft_tokens to []
# for all active requests. If it's on, we initialize py_draft_tokens
Expand Down Expand Up @@ -1250,11 +1263,10 @@ def _prepare_draft_requests(self):
continue

req.py_last_draft_tokens = req.py_draft_tokens
max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens

if max_total_draft_tokens > 0 and self.use_spec_decode and not req.py_disable_speculative_decoding:
req.py_draft_tokens = [0] * max_total_draft_tokens
req.py_draft_pages_allocated = max_total_draft_tokens
if self.max_total_draft_tokens > 0 and self.use_spec_decode and not req.py_disable_speculative_decoding:
req.py_draft_tokens = [0] * self.max_total_draft_tokens
req.py_draft_pages_allocated = self.max_total_draft_tokens
else:
req.py_draft_tokens = []
req.py_draft_pages_allocated = 0
Expand Down
10 changes: 6 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,12 @@ def allocation_scope(current_stage: ExecutorMemoryType,
RestoreMode.PINNED):
draft_spec_config = copy.copy(spec_config)

use_chain_drafter = (guided_decoding_config is None
and draft_spec_config._allow_chain_drafter and
draft_spec_config._allow_greedy_draft_tokens
and llm_args.attn_backend == "TRTLLM")
use_chain_drafter = (
guided_decoding_config is None
and draft_spec_config._allow_chain_drafter
and draft_spec_config._allow_greedy_draft_tokens
and llm_args.attn_backend == "TRTLLM"
and draft_spec_config.draft_len_schedule is None)

logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}")
if use_chain_drafter:
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,8 @@ def update_requests(
)
if get_draft_token_length(req) > 0:
req.py_num_accepted_draft_tokens = num_accepted
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
actual_draft_len = get_draft_token_length(req)
req.py_rewind_len = actual_draft_len - num_accepted
else:
req.py_num_accepted_draft_tokens = 0
req.py_rewind_len = 0
Expand Down
63 changes: 58 additions & 5 deletions tensorrt_llm/_torch/speculative/drafter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from abc import ABC, abstractmethod
from typing import List, Optional, final
from bisect import bisect_right
from typing import Dict, List, Optional, final

from tensorrt_llm.logger import logger

from ..pyexecutor.llm_request import LlmRequest, get_draft_token_length
from ..pyexecutor.resource_manager import ResourceManager
Expand All @@ -9,8 +12,16 @@
class Drafter(ABC):
"""Abstract base class for all drafter implementations."""

def __init__(self, max_concurrency: Optional[int] = None) -> None:
def __init__(self,
max_draft_len: int = None,
max_total_draft_tokens: int = None,
max_concurrency: Optional[int] = None,
draft_len_schedule: Optional[Dict[int, int]] = None) -> None:
self.max_draft_len = max_draft_len
self.max_total_draft_tokens = max_total_draft_tokens
self._static_max_total_draft_tokens = max_total_draft_tokens
self.max_concurrency = max_concurrency
self.draft_len_schedule = draft_len_schedule

@abstractmethod
def prepare_draft_tokens(
Expand Down Expand Up @@ -57,16 +68,58 @@ def should_use_spec_decode(self, requests: List[LlmRequest],
def pad_draft_tokens_for_cuda_graph(
self, scheduled_requests: ScheduledRequests) -> None:
"""
Pad draft tokens to the max draft length for CUDA graph compatibility.
Pad draft tokens to the static max total draft tokens for CUDA graph compatibility.

Args:
scheduled_requests: The scheduled requests to pad
"""
for req in scheduled_requests.generation_requests:
max_draft_tokens = self.max_draft_len
num_draft_tokens = get_draft_token_length(req)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens - num_draft_tokens))
0 for _ in range(self._static_max_total_draft_tokens -
num_draft_tokens))

def get_draft_len_for_batch_size(self, batch_size: int) -> int:
    """
    Look up the draft length configured for a given batch size.

    The schedule maps batch-size thresholds to draft lengths; the entry
    with the largest threshold not exceeding ``batch_size`` wins.

    Args:
        batch_size: Current batch size (schedule keys have been sorted by
            the config validator).

    Returns:
        The draft length to use for this batch size (0 disables speculation).
    """
    # Keys are already sorted ascending (guaranteed by the config
    # validator), so a binary search locates the applicable threshold.
    sorted_thresholds = list(self.draft_len_schedule.keys())

    # Insertion point for batch_size; every key left of it is <= batch_size.
    insertion_point = bisect_right(sorted_thresholds, batch_size)

    if insertion_point == 0:
        # No threshold <= batch_size, i.e. batch_size < 1. Should not happen
        # in practice; fall back defensively by turning speculation off.
        logger.warning(
            f"get_draft_len_for_batch_size called with batch_size={batch_size} < 1. "
            f"This is unexpected. Disabling speculation (returning draft_len=0)."
        )
        return 0

    # Largest threshold that is <= batch_size wins.
    return self.draft_len_schedule[sorted_thresholds[insertion_point - 1]]

def update_max_total_draft_tokens(self,
                                  new_max_total_draft_tokens: int) -> None:
    """
    Apply a new draft-token budget chosen for the current batch size.

    Only used when draft_len_schedule is provided in spec_config (dynamic
    draft length based on runtime batch size is enabled). Subclasses can
    override this to propagate the new value to their resource managers
    if needed.

    Args:
        new_max_total_draft_tokens: The new max total draft tokens.
    """
    # Keep both fields in lockstep; callers read either one.
    self.max_draft_len = self.max_total_draft_tokens = new_max_total_draft_tokens

def run_drafter_post(
self,
Expand Down
11 changes: 7 additions & 4 deletions tensorrt_llm/_torch/speculative/model_drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def __init__(
spec_resource_manager: Optional[BaseResourceManager] = None,
guided_decoder: Optional[GuidedDecoder] = None,
):
super().__init__(spec_config.max_concurrency)

# Validate required parameters
if draft_model_engine is None:
raise ValueError("draft_model_engine cannot be None")
Expand All @@ -74,6 +72,11 @@ def __init__(
raise ValueError("max_total_draft_tokens must be >= 0")
assert max_draft_len <= max_total_draft_tokens

super().__init__(max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
max_concurrency=spec_config.max_concurrency,
draft_len_schedule=spec_config.draft_len_schedule)

# Model and resource management
self.draft_model_engine = draft_model_engine
self.draft_seq_slot_manager = draft_seq_slot_manager
Expand All @@ -82,8 +85,7 @@ def __init__(

# Configuration
self.spec_config = spec_config
self.max_draft_len = max_draft_len
self.max_total_draft_tokens = max_total_draft_tokens

# Sampling
self.sampler = sampler
self.guided_decoder = guided_decoder
Expand All @@ -93,6 +95,7 @@ def __init__(
# TODO: enable sampling/guided decoding on static draft loop
assert guided_decoder is None
assert spec_config._allow_greedy_draft_tokens
assert spec_config.draft_len_schedule is None

# Create accumulator for draft tokens in non-CDL mode
self.draft_tokens_accumulator: Dict[int, List[int]] = {}
Expand Down
21 changes: 15 additions & 6 deletions tensorrt_llm/_torch/speculative/ngram.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class NGramPoolManager(BaseResourceManager):
`matches` is a list of candidate draft token ids attaching to a pattern.

Arguments:
max_draft_tokens: int
max_total_draft_tokens: int
The length maximum of draft tokens (can be understood as length maximum of output draft tokens).

max_matching_ngram_size: int
Expand All @@ -51,7 +51,7 @@ class NGramPoolManager(BaseResourceManager):

def __init__(self, spec_config: "NGramDecodingConfig",
max_num_requests: int):
self.max_draft_tokens = spec_config.max_draft_len
self.max_total_draft_tokens = spec_config.max_total_draft_tokens
self.max_matching_ngram_size = spec_config.max_matching_ngram_size
self.is_keep_all = spec_config.is_keep_all
self.is_use_oldest = spec_config.is_use_oldest # TODO: remove this if updating strategy is supported
Expand Down Expand Up @@ -107,7 +107,7 @@ def get_draft_tokens(
-1):
# Find each possible pattern-match combination, and use tuple for hash
for l in range(len(sequence) - size):
r = min(l + size + self.max_draft_tokens, len(sequence))
r = min(l + size + self.max_total_draft_tokens, len(sequence))
pattern = tuple(sequence[l:l + size])
new_match = tuple(sequence[l + size:r])
if pattern not in pool or \
Expand Down Expand Up @@ -138,7 +138,7 @@ def get_draft_tokens(
# Update start_index
self.start_index[request_id] = max(
0, prefix_len -
(self.max_draft_tokens + self.max_matching_ngram_size - 1))
(self.max_total_draft_tokens + self.max_matching_ngram_size - 1))

return draft_tokens

Expand Down Expand Up @@ -167,10 +167,13 @@ def __init__(
spec_config: NGramDecodingConfig,
ngram_pool_manager: NGramPoolManager = None,
):
super().__init__(spec_config.max_concurrency)
super().__init__(
max_draft_len=spec_config.max_draft_len,
max_total_draft_tokens=spec_config.max_total_draft_tokens,
max_concurrency=spec_config.max_concurrency,
draft_len_schedule=spec_config.draft_len_schedule)
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
self.spec_config = spec_config
self.max_draft_len = spec_config.max_draft_len
self.spec_resource_manager = ngram_pool_manager

def prepare_draft_tokens(
Expand All @@ -197,3 +200,9 @@ def prepare_draft_tokens(
request.py_max_new_tokens,
)
request.py_draft_tokens = draft_tokens

def update_max_total_draft_tokens(self,
                                  new_max_total_draft_tokens: int) -> None:
    """Override to propagate to NGramPoolManager.

    Updates the base Drafter bookkeeping via super(), then mirrors the new
    budget onto the pool manager, which reads max_total_draft_tokens when
    extracting candidate draft tokens from the pool.

    Args:
        new_max_total_draft_tokens: The new max total draft tokens.
    """
    super().update_max_total_draft_tokens(new_max_total_draft_tokens)
    # spec_resource_manager is the NGramPoolManager backing this drafter.
    self.spec_resource_manager.max_total_draft_tokens = new_max_total_draft_tokens
55 changes: 55 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,16 @@ class DecodingBaseConfig(StrictBaseModel):
# this value. Otherwise, speculation will always be on.
max_concurrency: Optional[int] = None

# Developer interface: dynamically adjust draft length based on active batch size in runtime.
# Maps batch size to draft lengths. For example:
# {1: 4, 4: 2, 8: 0} means:
# - batch_size >= 1: use draft_len=4
# - batch_size >= 4: use draft_len=2
# - batch_size >= 8: use draft_len=0 (disable speculation)
# draft_len_schedule is enforced to contain batch_size=1 and its according draft_len equals max_draft_len for consistency
# for example, if max_draft_len=4, the schedule must contain {1: 4}
draft_len_schedule: Optional[dict[int, int]] = None

load_format: Optional[str] = None
# PyTorch only.
# Rolling average window size (N) for acceptance length across completed requests.
Expand Down Expand Up @@ -597,6 +607,51 @@ def _validate_acceptance_length_threshold(cls, v: Optional[float]):
# If set, drafting uses greedy sampling, irrespective of sampling parameters.
_allow_greedy_draft_tokens: bool = PrivateAttr(True)

@field_validator('draft_len_schedule')
@classmethod
def validate_draft_len_schedule_and_sort(cls, v, info):
    """Validate and sort draft_len_schedule by batch size thresholds.

    Checks that every batch-size threshold is >= 1 and every draft length
    is >= 0, that the mandatory batch_size=1 entry is present and equal to
    max_draft_len, and that no draft length exceeds max_draft_len. Returns
    the schedule as a dict sorted ascending by threshold so later
    binary-search lookups can rely on key order.

    Raises:
        ValueError: If any threshold or draft length is out of range, the
            batch_size=1 entry is missing, or the schedule disagrees with
            max_draft_len.
    """
    if v is not None:
        # Validate values
        for batch_size, draft_len in v.items():
            if batch_size < 1:
                raise ValueError(
                    f"draft_len_schedule: batch size threshold must be >= 1, got {batch_size}"
                )
            if draft_len < 0:
                raise ValueError(
                    f"draft_len_schedule: draft length must be >= 0, got {draft_len}"
                )

        # Require batch_size=1 in schedule
        if 1 not in v:
            raise ValueError(
                "draft_len_schedule must include batch_size=1. "
                "All systems can have batch_size=1. Add {1: <max_draft_len>} to your schedule."
            )

        # Enforce schedule[1] == max_draft_len for consistency
        # NOTE(review): info.data only holds fields already validated;
        # assumes 'max_draft_len' is declared before draft_len_schedule
        # on the model — confirm, otherwise these checks silently skip.
        max_draft_len = info.data.get('max_draft_len')
        if max_draft_len is not None and v[1] != max_draft_len:
            raise ValueError(
                f"draft_len_schedule[1] must equal max_draft_len for consistency. "
                f"Got schedule[1]={v[1]}, but max_draft_len={max_draft_len}. "
                f"batch_size=1 should use maximum draft length.")

        # Enforce all draft lengths <= max_draft_len
        if max_draft_len is not None:
            for batch_size, draft_len in v.items():
                if draft_len > max_draft_len:
                    raise ValueError(
                        f"draft_len_schedule: all draft lengths must be <= max_draft_len. "
                        f"Got draft_len={draft_len} for batch_size={batch_size}, "
                        f"but max_draft_len={max_draft_len}.")

        # Return sorted dict (by batch size thresholds)
        # This ensures efficient lookup
        return dict(sorted(v.items(), key=lambda x: x[0]))
    return v

@classmethod
def from_dict(cls, data: dict):
# dispatch to the correct decoding config
Expand Down
Loading