
Commit ef7ee6a

[None][feat] Add environment variable to force spec-dec number of accepted tokens (#9371)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
1 parent b10137f commit ef7ee6a

File tree

4 files changed: +60 −11 lines changed


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 22 additions & 9 deletions
@@ -56,6 +56,7 @@
 from tensorrt_llm.sampling_params import SamplingParams
 
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
+from ..speculative.interface import get_force_num_accepted_tokens
 from ..speculative.spec_tree_manager import SpecTreeManager
 from .finish_reason import FinishedState
 from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
@@ -662,6 +663,9 @@ def __init__(self, args: Args):
         self._global_seed = 42
         self._generator = None
 
+        # Force number of accepted tokens for speculative decoding testing
+        self._force_num_accepted_tokens = get_force_num_accepted_tokens()
+
     def get_generator(self, device: torch.device) -> torch.Generator:
         """Get a deterministic generator for the specified device.
 
@@ -784,15 +788,24 @@ def _process_draft_tokens_greedy(
             return 0
         num_accepted = 0
 
-        for draft_token in request.py_draft_tokens:
-            if draft_token != new_token:
-                # Reject.
-                break
-
-            num_accepted += 1
-            new_token = add_token(request, new_tokens, beam=BEAM, step=num_accepted)
-            if self.finish_if_reason(request, finish_reasons, step=num_accepted):
-                break
+        if self._force_num_accepted_tokens != 0:
+            # Force acceptance of up to force_num_accepted_tokens draft tokens
+            force_limit = min(self._force_num_accepted_tokens, len(request.py_draft_tokens))
+            for _ in request.py_draft_tokens[:force_limit]:
+                num_accepted += 1
+                new_token = add_token(request, new_tokens, beam=BEAM, step=num_accepted)
+                if self.finish_if_reason(request, finish_reasons, step=num_accepted):
+                    break
+        else:
+            for draft_token in request.py_draft_tokens:
+                if draft_token != new_token:
+                    # Reject.
+                    break
+
+                num_accepted += 1
+                new_token = add_token(request, new_tokens, beam=BEAM, step=num_accepted)
+                if self.finish_if_reason(request, finish_reasons, step=num_accepted):
+                    break
         return num_accepted
 
     def _process_draft_tokens_tree(
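
For reference, the greedy path above only changes how num_accepted is counted. Below is a minimal standalone sketch of that counting rule with the request bookkeeping (add_token, finish_if_reason, early finish) stripped away; count_accepted and its arguments are illustrative only, not part of the patch.

# Illustrative sketch of the greedy acceptance counting; not part of the patch.
from typing import List


def count_accepted(draft_tokens: List[int], target_tokens: List[int],
                   force_num_accepted: int = 0) -> int:
    """Count accepted draft tokens, optionally forcing acceptance."""
    if force_num_accepted != 0:
        # Accept up to force_num_accepted drafts unconditionally,
        # capped by how many draft tokens the request actually has.
        return min(force_num_accepted, len(draft_tokens))
    num_accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break  # First mismatch rejects the remaining drafts.
        num_accepted += 1
    return num_accepted


# Normal greedy acceptance: only the first two drafts match the target tokens.
assert count_accepted([5, 7, 9], [5, 7, 2]) == 2
# Forced acceptance ignores the mismatch but is capped by the draft length.
assert count_accepted([5, 7, 9], [5, 7, 2], force_num_accepted=8) == 3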

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 7 additions & 1 deletion
@@ -12,7 +12,7 @@
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata
+from .interface import SpecMetadata, get_force_num_accepted_tokens
 from .mtp import MTPSampler
 from .spec_tree_manager import SpecTreeManager
 
@@ -365,6 +365,7 @@ def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
         self.guided_decoder: Optional[CapturableGuidedDecoder] = None
+        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
 
     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
@@ -527,6 +528,11 @@ def sample_and_accept_draft_tokens(
         num_accepted_tokens[num_contexts:] += torch.cumprod(
             (draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(),
             dim=-1).sum(1)
+        # Check for environment variable override
+        if self.force_num_accepted_tokens != 0:
+            force_num_accepted_tokens = min(self.force_num_accepted_tokens,
+                                            self.max_draft_len + 1)
+            num_accepted_tokens[num_contexts:] = force_num_accepted_tokens
         return accepted_tokens, num_accepted_tokens
 
     def draft_decoder(
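
The cumprod expression above counts, per generation request, how many leading draft tokens match the target model's tokens; the new branch then overwrites that count with the forced value, clamped to max_draft_len + 1. The following is a small self-contained illustration of the same tensor arithmetic, assuming the counter starts at one accepted (bonus) token per request and ignoring the num_contexts slicing; all tensor values are made up.

# Illustration of the acceptance count and its override (made-up values).
import torch

max_draft_len = 3
draft_tokens = torch.tensor([[4, 8, 1], [4, 8, 1]])            # two generation requests
gen_target_tokens = torch.tensor([[4, 8, 2, 9], [4, 7, 1, 9]])  # target tokens per step

# Assume one bonus token is always accepted; cumprod counts leading matches.
num_accepted = 1 + torch.cumprod(
    (draft_tokens == gen_target_tokens[:, :max_draft_len]).int(), dim=-1).sum(1)
print(num_accepted)  # tensor([3, 2]): 2 matches + bonus, 1 match + bonus

# With the override active, the count is replaced and clamped to max_draft_len + 1.
force = 10
num_accepted[:] = min(force, max_draft_len + 1)
print(num_accepted)  # tensor([4, 4])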

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 23 additions & 0 deletions
@@ -1,14 +1,37 @@
 import copy
+import os
 from dataclasses import dataclass, field
 from enum import IntEnum, auto
 from typing import List, Optional, Type
 
 import torch
 
+from tensorrt_llm.logger import logger
+
 from ..._utils import get_sm_version
 from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
 from ..pyexecutor.resource_manager import BaseResourceManager
 
+# Environment variable name for forcing the number of accepted tokens in speculative decoding
+FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
+
+
+def get_force_num_accepted_tokens() -> int:
+    """
+    Read and parse the TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS environment variable.
+
+    Returns:
+        int: The forced number of accepted tokens, or 0 if not set or invalid.
+    """
+    env_value = os.environ.get(FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR, "0")
+    try:
+        return int(env_value)
+    except ValueError:
+        logger.warning(
+            f"{FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR} must be a valid integer, "
+            f"got '{env_value}'. Using default value 0.")
+        return 0
+
 
 class SpeculativeDecodingMode(IntEnum):
     MTP = auto()
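
Since the samplers call this helper once in their constructors, the variable has to be set before the executor/sampler objects are built. A minimal usage sketch, assuming tensorrt_llm is importable in the environment:

# Hypothetical test setup: force every speculative step to "accept" 3 tokens.
import os

os.environ["TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"] = "3"

from tensorrt_llm._torch.speculative.interface import get_force_num_accepted_tokens

print(get_force_num_accepted_tokens())  # 3

# A non-integer value is rejected with a warning and falls back to 0.
os.environ["TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"] = "many"
print(get_force_num_accepted_tokens())  # 0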

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 8 additions & 1 deletion
@@ -17,7 +17,7 @@
                                    SampleStateTensors, TorchSampler, add_token,
                                    int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata
+from .interface import SpecMetadata, get_force_num_accepted_tokens
 
 if TYPE_CHECKING:
     from tensorrt_llm.llmapi.llm_args import MTPDecodingConfig
@@ -347,6 +347,7 @@ def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
         self.model_config = model_config
         self.is_thop = False
         self.guided_decoder: Optional[CapturableGuidedDecoder] = None
+        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
 
     def forward(
         self,
@@ -895,6 +896,12 @@ def sample_and_accept_draft_tokens(
             ).int(),
             dim=-1).sum(1)
 
+        # Check for environment variable override
+        if self.force_num_accepted_tokens != 0:
+            force_num_accepted_tokens = min(self.force_num_accepted_tokens,
+                                            mtp_num_modules + 1)
+            num_accepted_tokens[num_contexts:] = force_num_accepted_tokens
+
         return accepted_tokens, num_accepted_tokens
 
     def change_attn_metadata(self, num_accepted_tokens: torch.Tensor,
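
Taken together, the override behaves consistently across the three samplers touched by this commit: the EAGLE3 and MTP paths clamp the forced count to max_draft_len + 1 and mtp_num_modules + 1 respectively (drafts plus the bonus token), while the greedy TorchSampler path caps it at the number of draft tokens attached to the request. A value of 0, or an unset or invalid variable, leaves the normal acceptance logic in place.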
