import os
import sys
import unittest
from unittest.mock import patch

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
                                 KvCacheConfig)
from tensorrt_llm.logger import logger

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
1617
1718
@pytest.fixture(scope="function")
def enforce_single_worker(monkeypatch):
    """Force the LLM worker to run in a single process.

    unittest.mock patches do not propagate across process boundaries, so any
    test that patches worker-side code must keep everything in-process.
    """
    # TLLM_WORKER_USE_SINGLE_PROCESS=1 keeps the worker in this process,
    # where the mock installed by the test is visible.
    monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
    yield
25+
# Tests that the SpeculationGate correctly disables speculative decoding
# when the average acceptance rate drops below the threshold.
# This test uses a mock to simulate low acceptance rates and verifies
# that the spec gate triggers and disables speculation.
@pytest.mark.high_cuda_memory
def test_spec_gate_e2e(enforce_single_worker):
    """End-to-end check that SpeculationGate turns speculation off.

    A mock on SpeculationGate.record_avg_decoded feeds a constant low
    acceptance value, so after `acceptance_window` completed requests the
    gate must disable speculative decoding exactly once.
    """
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 35:
        pytest.skip("Not enough memory to load target + draft model")

    # NOTE(review): a few setup lines were elided from this view by the diff
    # hunks "@@ -32,6 +38,8 @@" and "@@ -48,39 +56,87 @@" — presumably the
    # model directories (via llm_models_root) and the llm_common_config dict
    # used below. Restore them from version control; verify before merging.

    max_batch_size = 2
    max_draft_len = 4
    acceptance_window = 3
    acceptance_threshold = 0.6
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
    cuda_graph_config = CudaGraphConfig(batch_sizes=[1])

    spec_config = EagleDecodingConfig(
        max_draft_len=max_draft_len,
        speculative_model_dir=eagle_model_dir,
        eagle3_one_model=False,
        # Keep max_concurrency huge so speculation is never turned off for
        # concurrency reasons — only the acceptance-rate gate is under test.
        max_concurrency=10000,
        acceptance_window=acceptance_window,
        acceptance_length_threshold=acceptance_threshold,
    )

    prompts = [
        "The capital of France is",
        "The president of the United States is",
        "What is the capital of Australia?",
    ]
    sampling_params = SamplingParams(max_tokens=20, temperature=0)

    # Track calls to record_avg_decoded and whether the gate ever disabled.
    gate_state = {"record_calls": [], "gate_disabled": False}

    original_record_avg_decoded = SpeculationGate.record_avg_decoded

    def mock_record_avg_decoded(self,
                                avg_decoded_tokens_per_iter,
                                request_id=None):
        """Simulate a low acceptance rate, then delegate to the real method.

        avg_decoded = 1.2 tokens/iter means only 0.2 accepted tokens/iter,
        below the 0.6 threshold, so the gate should trigger once the
        acceptance window fills.
        """
        simulated_low_avg = 1.2
        disabled_now, avg = original_record_avg_decoded(
            self, simulated_low_avg, request_id)

        gate_state["record_calls"].append({
            "original_avg": avg_decoded_tokens_per_iter,
            "simulated_avg": simulated_low_avg,
            "disabled_now": disabled_now,
            "avg_accept": avg,
            "request_id": request_id,
        })
        if disabled_now:
            gate_state["gate_disabled"] = True

        return disabled_now, avg

    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
    try:
        with patch.object(SpeculationGate, 'record_avg_decoded',
                          mock_record_avg_decoded):
            llm_spec.generate(prompts, sampling_params)

        # The mock must have been exercised (requests completed).
        assert len(gate_state["record_calls"]
                   ) > 0, "record_avg_decoded should have been called"

        # The gate must disable after enough low-acceptance requests.
        assert gate_state["gate_disabled"], \
            f"Gate should have been disabled with simulated low acceptance. Calls: {gate_state['record_calls']}"

        # The gate can only trigger once the window is filled, i.e. not
        # before the `acceptance_window`-th call (index = window - 1).
        disable_indices = [
            i for i, call in enumerate(gate_state["record_calls"])
            if call["disabled_now"]
        ]
        assert len(disable_indices) == 1, \
            f"Gate should have triggered exactly once, but triggered at indices: {disable_indices}"
        assert disable_indices[0] >= acceptance_window - 1, \
            f"Gate should trigger after window ({acceptance_window}) is filled, but triggered at index {disable_indices[0]}"

        # At the moment of disabling, the recorded average acceptance must
        # be below the configured threshold.
        disable_call = gate_state["record_calls"][disable_indices[0]]
        assert disable_call["avg_accept"] is not None
        assert disable_call["avg_accept"] < acceptance_threshold, \
            f"Avg acceptance ({disable_call['avg_accept']}) should be below threshold ({acceptance_threshold})"

        logger.debug(
            f"Gate correctly triggered after {disable_indices[0] + 1} requests")
        logger.debug(
            f"Final avg acceptance: {disable_call['avg_accept']:.3f} < threshold {acceptance_threshold}"
        )
    finally:
        # Always release the engine — previously shutdown() was skipped
        # whenever an assertion above failed, leaking the worker/GPU memory.
        llm_spec.shutdown()
84140
85141
86142def test_returns_none_until_window_and_enabled_when_above_threshold ():