
Commit 4f80961

zheyuf authored and dominicshanshan committed
[TRTLLM-7412][feat] Turn off spec decode when the rolling average acceptance length drops below threshold. (NVIDIA#7283)
Signed-off-by: Zheyu Fu <[email protected]>
1 parent ff19043 commit 4f80961
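The feature is configured through two new speculative-decoding options, acceptance_window and acceptance_length_threshold (added to DecodingBaseConfig in llm_args.py below). A minimal usage sketch, mirroring the EAGLE3 setup in the end-to-end test added by this commit; the model paths are placeholders:

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import EagleDecodingConfig

    spec_config = EagleDecodingConfig(
        max_draft_len=4,
        speculative_model_dir="<eagle3-draft-model-dir>",  # placeholder path
        eagle3_one_model=False,
        # Permanently turn speculation off once the rolling average acceptance
        # length over the last 5 completed requests drops below 0.6.
        acceptance_window=5,
        acceptance_length_threshold=0.6,
    )
    llm = LLM(model="<target-model-dir>", speculative_config=spec_config)  # placeholder path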

File tree

5 files changed: +298 -5 lines changed


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 48 additions & 4 deletions
@@ -40,6 +40,7 @@
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
 from .handle_additional_outputs import HandleAdditionalOutputs
@@ -211,6 +212,20 @@ def __init__(self,
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()

+        # Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold)
+        spec_config = getattr(self.model_engine, 'spec_config', None)
+        self.acceptance_window = getattr(
+            spec_config, 'acceptance_window',
+            None) if spec_config is not None else None
+        self.acceptance_length_threshold = getattr(
+            spec_config, 'acceptance_length_threshold',
+            None) if spec_config is not None else None
+        self.speculation_permanently_disabled = False
+        self.speculation_gate = None
+        if self.acceptance_window and self.acceptance_length_threshold is not None:
+            self.speculation_gate = SpeculationGate(
+                self.acceptance_window, self.acceptance_length_threshold)
+
         # response used data
         self.response_lock = threading.Lock()
         self.response_cv = threading.Condition(self.response_lock)
@@ -1018,10 +1033,15 @@ def _prepare_and_schedule_batch(self):
         self._pad_attention_dp_dummy_request()

         if self.drafter is not None:
-            self.use_spec_decode = self.drafter.should_use_spec_decode(
-                self.active_requests, self.max_batch_size,
-                self.model_engine.max_num_tokens,
-                self.model_engine.spec_config.max_draft_len)
+            # Honor permanent disable flag based on rolling acceptance first
+            if getattr(self, 'speculation_permanently_disabled', False):
+                self.use_spec_decode = False
+            else:
+                self.use_spec_decode = self.drafter.should_use_spec_decode(
+                    self.active_requests, self.max_batch_size,
+                    self.model_engine.max_num_tokens,
+                    self.model_engine.spec_config.max_draft_len)
+            logger.debug(f"Use spec decode: {self.use_spec_decode}")
             self.model_engine.enable_spec_decode = self.use_spec_decode

             # Set up draft_tokens in active_requests, because they could be used in the scheduling stage.
@@ -2074,6 +2094,30 @@ def _handle_responses(self):
                 new_responses.append((req_id, response))

             if request_done:
+                if (self.drafter is not None and getattr(
+                        self.model_engine, 'enable_spec_decode', False)
+                        and not self.speculation_permanently_disabled
+                        and not request.is_dummy and not self.is_warmup):
+                    if self.speculation_gate is not None:
+                        # Response handling runs on multiple PP ranks. Only the last PP rank performs
+                        # sampling; restrict rolling stat updates to it to avoid overcounting.
+                        if (not getattr(self.dist, 'has_pp',
+                                        False)) or self.dist.is_last_pp_rank:
+                            avg_decoded = getattr(
+                                request, 'avg_decoded_tokens_per_iter', None)
+                            if avg_decoded is not None:
+                                disabled_now, _ = self.speculation_gate.record_avg_decoded(
+                                    avg_decoded,
+                                    request_id=getattr(request, 'py_request_id',
+                                                       None))
+                                if disabled_now:
+                                    # disable speculation permanently
+                                    # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
+                                    self.speculation_permanently_disabled = True
+                            else:
+                                logger.debug(
+                                    f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter"
+                                )
                 if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa:
                     requests_to_terminate.append(request)
                 else:
tensorrt_llm/_torch/speculative/speculation_gate.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+from collections import deque
+from typing import Deque, Optional, Tuple
+
+from tensorrt_llm.logger import logger
+
+
+class SpeculationGate:
+    """
+    Tracks the rolling average of accepted draft tokens per iteration over the last N completed requests.
+    Permanently disables speculation when the average falls below a threshold.
+    """
+
+    def __init__(self, window: int, threshold: float):
+        self.window = window
+        self.threshold = threshold
+        self.acceptance_history: Deque[float] = deque()
+        self.acceptance_sum: float = 0.0
+        self.num_completed_for_acceptance = 0
+        self.disabled = False
+        logger.debug(
+            f"[SpeculationGate] SpeculationGate initialized with window={self.window}, threshold={self.threshold}"
+        )
+
+    def reset(self) -> None:
+        self.acceptance_history.clear()
+        self.acceptance_sum = 0.0
+        self.num_completed_for_acceptance = 0
+        self.disabled = False
+
+    def record_avg_decoded(
+            self,
+            avg_decoded_tokens_per_iter: float,
+            request_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
+        """
+        Record a completed request's avg_decoded_tokens_per_iter.
+        Returns (disabled_now, current_avg_accept), where disabled_now is True only on the call that triggers the disable.
+        """
+        if self.disabled or self.window is None or self.window <= 0 or self.threshold is None:
+            return False, None
+
+        # Extra guard: if the caller passed None, skip updating the rolling stats
+        if avg_decoded_tokens_per_iter is None:
+            return False, None
+
+        accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0)
+
+        # Log per-request completion for debugging
+        if request_id is not None:
+            logger.debug(
+                f"[SpeculationGate] Request {request_id} completed: avg_decoded={avg_decoded_tokens_per_iter}, accepted_len={accepted_len:.3f}"
+            )
+
+        # O(1) rolling update
+        self.acceptance_history.append(accepted_len)
+        logger.debug(
+            f"[SpeculationGate] Acceptance history: {self.acceptance_history}")
+        self.acceptance_sum += accepted_len
+        if len(self.acceptance_history) > self.window:
+            removed = self.acceptance_history.popleft()
+            self.acceptance_sum -= removed
+
+        self.num_completed_for_acceptance += 1
+
+        if self.num_completed_for_acceptance >= self.window:
+            avg_accept = self.acceptance_sum / len(self.acceptance_history)
+            if avg_accept < self.threshold:
+                self.disabled = True
+                logger.info(
+                    f"[SpeculationGate] Speculative decoding disabled: rolling acceptance avg {avg_accept:.3f} < threshold {self.threshold} over last {self.window} requests"
+                )
+                return True, avg_accept
+            else:
+                # speculation is still enabled
+                return False, avg_accept
+
+        return False, None
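The unit tests added later in this commit exercise SpeculationGate directly; a condensed sketch of the rolling-average behavior, using the same numbers as those tests:

    from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate

    gate = SpeculationGate(window=3, threshold=0.8)
    print(gate.record_avg_decoded(2.0))  # (False, None): accepted_len 1.0, window not yet full
    print(gate.record_avg_decoded(2.0))  # (False, None)
    print(gate.record_avg_decoded(2.0))  # (False, 1.0): window full, 1.0 >= 0.8 keeps speculation on
    print(gate.record_avg_decoded(1.2))  # (True, ~0.733): (1.0 + 1.0 + 0.2) / 3 < 0.8 disables it
    print(gate.disabled)                 # True; all later calls return (False, None)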

tensorrt_llm/llmapi/llm_args.py

Lines changed: 30 additions & 0 deletions
@@ -365,6 +365,36 @@ class DecodingBaseConfig(StrictBaseModel):
     max_concurrency: Optional[int] = None

     load_format: Optional[str] = None
+    # PyTorch only.
+    # Rolling average window size (N) for acceptance length across completed requests.
+    # If not set or set to 0, the feature is disabled.
+    acceptance_window: Optional[int] = None
+    # PyTorch only.
+    # Threshold for average acceptance length; speculation will be disabled
+    # permanently once the rolling average over the last N completed requests
+    # (N = acceptance_window) drops below this value.
+    acceptance_length_threshold: Optional[float] = None
+
+    # Validate acceptance controls at field level so they run on model creation
+    @field_validator('acceptance_window')
+    @classmethod
+    def _validate_acceptance_window(cls, v: Optional[int]):
+        if v is None:
+            return v
+        if v < 0:
+            raise ValueError(
+                f"acceptance_window must be >= 0 (0 disables), got {v}")
+        return v
+
+    @field_validator('acceptance_length_threshold')
+    @classmethod
+    def _validate_acceptance_length_threshold(cls, v: Optional[float]):
+        if v is None:
+            return v
+        if v < 0:
+            raise ValueError(
+                f"acceptance_length_threshold must be >= 0, got {v}")
+        return v

     # If set, drafting is allowed to use chain drafter.
     _allow_chain_drafter: bool = PrivateAttr(True)
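Because these are pydantic field validators on DecodingBaseConfig, bad values are rejected when the config object is created. A small, hypothetical illustration (assuming pydantic v2 wraps the validator's ValueError in a ValidationError, and that any other required EagleDecodingConfig fields are supplied or left at defaults):

    import pydantic

    from tensorrt_llm.llmapi import EagleDecodingConfig

    try:
        EagleDecodingConfig(max_draft_len=4, acceptance_window=-1)
    except pydantic.ValidationError as err:
        print(err)  # reports "acceptance_window must be >= 0 (0 disables), got -1"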

tests/unittest/_torch/speculative/test_dynamic_spec_decode.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def prepare_draft_tokens(self,
         max_num_tokens=4096 * 8,
         max_draft_len=4)

-    # Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(8, 12, 5) = 5 <= 6 → True
+    # Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(12, 8, 5) = 5 <= 6 → True
     active_requests = [object()] * 12
     assert drafter.should_use_spec_decode(active_requests,
                                           max_batch_size=8,
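For reference, the corrected comment spells out the heuristic's arithmetic. Below is a hedged sketch of that relationship, not the actual Drafter.should_use_spec_decode implementation, with 28 as the token budget and 6 as the assumed max_concurrency from this test:

    def should_use_spec_decode_sketch(num_requests: int, max_batch_size: int,
                                      max_num_tokens: int, max_draft_len: int,
                                      max_concurrency: int) -> bool:
        # Each speculating request needs room for 1 target token plus max_draft_len draft tokens.
        token_cap = max_num_tokens // (1 + max_draft_len)
        num_effective_requests = min(num_requests, max_batch_size, token_cap)
        return num_effective_requests <= max_concurrency

    # Matches the comment: 28 // (1 + 4) = 5; min(12, 8, 5) = 5 <= 6 -> True
    assert should_use_spec_decode_sketch(12, 8, 28, 4, 6)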
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+import os
+import sys
+import unittest
+
+import pytest
+import torch
+from utils.llm_data import llm_models_root
+from utils.util import similar
+
+from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate
+from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
+                                 KvCacheConfig)
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+
+# Tests the end-to-end functionality of the SpeculationGate, which turns off
+# spec decode when the average acceptance length falls below the threshold.
+# The gate is configured via acceptance_window and acceptance_length_threshold in spec_config.
+# max_concurrency is set to a large value so spec decode is not turned off because the
+# number of effective requests exceeds max_concurrency, isolating the effect of the SpeculationGate.
+@pytest.mark.high_cuda_memory
+def test_spec_gate_e2e():
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+    models_path = llm_models_root()
+    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
+    max_batch_size = 2
+    max_draft_len = 4
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
+    cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
+
+    llm_common_config = dict(
+        model=target_model_dir,
+        attn_backend="TRTLLM",
+        disable_overlap_scheduler=True,
+        cuda_graph_config=cuda_graph_config,
+        max_batch_size=max_batch_size,
+        kv_cache_config=kv_cache_config,
+        max_seq_len=4096,
+    )
+
+    spec_config = EagleDecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model_dir=eagle_model_dir,
+        # Llama 3 does not support one-model eagle.
+        eagle3_one_model=False,
+        max_concurrency=10000,
+        acceptance_window=5,
+        acceptance_length_threshold=0.6,
+    )
+
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+    # Output tests
+    prompts = [
+        "The capital of France is",
+        "The president of the United States is",
+        "What is the capital of Australia?",
+        "Explain in one sentence why the sky is blue.",
+        "Who wrote the book 'Pride and Prejudice'?",
+        "List three U.S. national holidays in the year 2025.",
+        "What is the currency of Japan?",
+        "How many players are on a basketball court for one team?",
+        "List three primary colors.",
+    ]
+    sampling_params = SamplingParams(max_tokens=32, temperature=0)
+
+    results_spec = llm_spec.generate(prompts, sampling_params)
+    generated_text_spec = [result.outputs[0].text for result in results_spec]
+    llm_spec.shutdown()
+
+    llm_ref = LLM(**llm_common_config)
+    results_ref = llm_ref.generate(prompts, sampling_params)
+    generated_text_ref = [result.outputs[0].text for result in results_ref]
+    llm_ref.shutdown()
+
+    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
+        assert similar(text_spec, text_ref)
+
+
+def test_returns_none_until_window_and_enabled_when_above_threshold():
+    gate = SpeculationGate(window=3, threshold=0.5)
+
+    disabled, avg = gate.record_avg_decoded(2.0, request_id=1)
+    assert disabled is False and avg is None
+    assert gate.disabled is False
+
+    disabled, avg = gate.record_avg_decoded(2.0, request_id=2)
+    assert disabled is False and avg is None
+    assert gate.disabled is False
+
+    disabled, avg = gate.record_avg_decoded(2.0, request_id=3)
+    assert disabled is False
+    assert avg == pytest.approx(1.0, rel=1e-6)
+    assert gate.disabled is False
+
+
+def test_disables_when_avg_below_threshold_and_stays_disabled():
+    gate = SpeculationGate(window=3, threshold=0.7)
+
+    gate.record_avg_decoded(1.1)
+    gate.record_avg_decoded(1.2)
+
+    disabled, avg = gate.record_avg_decoded(1.3)
+    assert disabled is True
+    assert avg == pytest.approx(0.2, rel=1e-6)
+    assert gate.disabled is True
+
+    # Once disabled, subsequent calls do nothing and return (False, None)
+    disabled, avg = gate.record_avg_decoded(100.0)
+    assert disabled is False and avg is None
+    assert gate.disabled is True
+
+    disabled, avg = gate.record_avg_decoded(200.0)
+    assert disabled is False and avg is None
+    assert gate.disabled is True
+
+
+def test_rolling_window_and_disable_on_drop():
+    gate = SpeculationGate(window=3, threshold=0.8)
+
+    # First three high-acceptance requests keep it enabled
+    gate.record_avg_decoded(2.0)
+    gate.record_avg_decoded(2.0)
+    disabled, avg = gate.record_avg_decoded(2.0)
+    assert disabled is False
+    assert avg == pytest.approx(1.0, rel=1e-6)
+    assert gate.disabled is False
+
+    # Fourth, lower value enters the window -> average drops below threshold -> disable
+    disabled, avg = gate.record_avg_decoded(1.2)
+    assert disabled is True
+    assert avg == pytest.approx((1.0 + 1.0 + 0.2) / 3.0, rel=1e-6)
+    assert gate.disabled is True
+
+
+if __name__ == "__main__":
+    unittest.main()
