Skip to content

Commit 8922ca8

Browse files
committed
Change from correctness check to functional check and unwaive the test.
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent 7175d89 commit 8922ca8

File tree

2 files changed

+85
-30
lines changed

2 files changed

+85
-30
lines changed

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,6 @@ accuracy/test_llm_api_pytorch.py::TestNemotronH_56B_Base::test_auto_dtype[tp8-cu
403403
accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8ep4-cuda_graph=True] SKIP (https://nvbugs/5707145)
404404
accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-cuda_graph=True] SKIP (https://nvbugs/5707145)
405405
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] SKIP (https://nvbugs/5596343)
406-
unittest/_torch/speculative/test_spec_gate.py::test_spec_gate_e2e SKIP (https://nvbugs/5710045)
407406
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5569696)
408407
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/5715568)
409408
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] SKIP (https://nvbugs/5715568)

tests/unittest/_torch/speculative/test_spec_gate.py

Lines changed: 85 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,34 @@
11
import os
22
import sys
33
import unittest
4+
from unittest.mock import patch
45

56
import pytest
67
import torch
78
from utils.llm_data import llm_models_root
8-
from utils.util import similar, skip_blackwell
99

1010
from tensorrt_llm import LLM, SamplingParams
1111
from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate
1212
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
1313
KvCacheConfig)
14+
from tensorrt_llm.logger import logger
1415

1516
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
1617

1718

18-
# It tests the end-to-end functionality of the SpeculationGate,
19-
# which will turn off spec decode when the average acceptance length is below the threshold.
20-
# It is set with acceptance window and acceptance threshold in spec_config.
21-
# This test sets max_concurrency to a large value to prevent spec decode from being turned off because the number of effective requests exceeds max_concurrency,
22-
# so that we can focus only on the turn-off effect from the SpeculationGate.
23-
@skip_blackwell # TODO: Remove after fixing TRTLLM-GEN FMHA segfault on Blackwell. NVBugs: https://nvbugspro.nvidia.com/bug/5698292
19+
@pytest.fixture(scope="function")
20+
def enforce_single_worker(monkeypatch):
21+
"""Mock functions don't work with multiple processes, so we enforce a single worker."""
22+
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
23+
yield
24+
25+
26+
# Tests that the SpeculationGate correctly disables speculative decoding
27+
# when the average acceptance rate drops below the threshold.
28+
# This test uses a mock to simulate low acceptance rates and verifies
29+
# that the spec gate triggers and disables speculation.
2430
@pytest.mark.high_cuda_memory
25-
def test_spec_gate_e2e():
31+
def test_spec_gate_e2e(enforce_single_worker):
2632
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
2733
if total_mem_gb < 35:
2834
pytest.skip("Not enough memory to load target + draft model")
@@ -32,6 +38,8 @@ def test_spec_gate_e2e():
3238

3339
max_batch_size = 2
3440
max_draft_len = 4
41+
acceptance_window = 3
42+
acceptance_threshold = 0.6
3543
kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
3644
cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
3745

@@ -48,39 +56,87 @@ def test_spec_gate_e2e():
4856
spec_config = EagleDecodingConfig(
4957
max_draft_len=max_draft_len,
5058
speculative_model_dir=eagle_model_dir,
51-
# Llama 3 does not support one model eagle.
5259
eagle3_one_model=False,
5360
max_concurrency=10000,
54-
acceptance_window=5,
55-
acceptance_length_threshold=0.6,
61+
acceptance_window=acceptance_window,
62+
acceptance_length_threshold=acceptance_threshold,
5663
)
5764

58-
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
59-
# Output tests
6065
prompts = [
6166
"The capital of France is",
6267
"The president of the United States is",
6368
"What is the capital of Australia?",
64-
"Explain in one sentence why the sky is blue.",
65-
"Who wrote the book 'Pride and Prejudice'?",
66-
"List three U.S. national holidays in the year 2025.",
67-
"What is the currency of Japan?",
68-
"How many players are on a basketball court for one team?",
69-
"List three primary colors.",
7069
]
71-
sampling_params = SamplingParams(max_tokens=32, temperature=0)
70+
sampling_params = SamplingParams(max_tokens=20, temperature=0)
71+
72+
# Track calls to record_avg_decoded and the disabled state
73+
gate_state = {"record_calls": [], "gate_disabled": False}
74+
75+
original_record_avg_decoded = SpeculationGate.record_avg_decoded
76+
77+
def mock_record_avg_decoded(self,
78+
avg_decoded_tokens_per_iter,
79+
request_id=None):
80+
"""
81+
Mock that simulates low acceptance rate (1.2 tokens/iter = 0.2 accepted).
82+
This is below the threshold of 0.6, so the gate should trigger after the window fills.
83+
"""
84+
# Simulate low acceptance: avg_decoded = 1.2 means accepted_len = 0.2
85+
# This is below threshold (0.6), so gate should trigger
86+
simulated_low_avg = 1.2
87+
disabled_now, avg = original_record_avg_decoded(self, simulated_low_avg,
88+
request_id)
89+
90+
gate_state["record_calls"].append({
91+
"original_avg": avg_decoded_tokens_per_iter,
92+
"simulated_avg": simulated_low_avg,
93+
"disabled_now": disabled_now,
94+
"avg_accept": avg,
95+
"request_id": request_id,
96+
})
97+
if disabled_now:
98+
gate_state["gate_disabled"] = True
99+
100+
return disabled_now, avg
72101

73-
results_spec = llm_spec.generate(prompts, sampling_params)
74-
generated_text_spec = [result.outputs[0].text for result in results_spec]
75-
llm_spec.shutdown()
102+
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
103+
104+
with patch.object(SpeculationGate, 'record_avg_decoded',
105+
mock_record_avg_decoded):
106+
llm_spec.generate(prompts, sampling_params)
76107

77-
llm_ref = LLM(**llm_common_config)
78-
results_ref = llm_ref.generate(prompts, sampling_params)
79-
generated_text_ref = [result.outputs[0].text for result in results_ref]
80-
llm_ref.shutdown()
108+
# Verify the mock was called (requests completed)
109+
assert len(gate_state["record_calls"]
110+
) > 0, "record_avg_decoded should have been called"
81111

82-
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
83-
assert similar(text_spec, text_ref)
112+
# Verify the gate was disabled after enough requests with low acceptance
113+
assert gate_state["gate_disabled"], \
114+
f"Gate should have been disabled with simulated low acceptance. Calls: {gate_state['record_calls']}"
115+
116+
# Verify the gate triggered at the right time (after window is filled)
117+
# The gate should trigger on the `acceptance_window`-th call (index = window - 1)
118+
disable_indices = [
119+
i for i, call in enumerate(gate_state["record_calls"])
120+
if call["disabled_now"]
121+
]
122+
assert len(disable_indices) == 1, \
123+
f"Gate should have triggered exactly once, but triggered at indices: {disable_indices}"
124+
assert disable_indices[0] >= acceptance_window - 1, \
125+
f"Gate should trigger after window ({acceptance_window}) is filled, but triggered at index {disable_indices[0]}"
126+
127+
# Verify the average acceptance was below threshold when disabled
128+
disable_call = gate_state["record_calls"][disable_indices[0]]
129+
assert disable_call["avg_accept"] is not None
130+
assert disable_call["avg_accept"] < acceptance_threshold, \
131+
f"Avg acceptance ({disable_call['avg_accept']}) should be below threshold ({acceptance_threshold})"
132+
133+
logger.debug(
134+
f"Gate correctly triggered after {disable_indices[0] + 1} requests")
135+
logger.debug(
136+
f"Final avg acceptance: {disable_call['avg_accept']:.3f} < threshold {acceptance_threshold}"
137+
)
138+
139+
llm_spec.shutdown()
84140

85141

86142
def test_returns_none_until_window_and_enabled_when_above_threshold():

0 commit comments

Comments
 (0)