@@ -29,19 +29,30 @@ def enforce_single_worker():
 # ============================================================================
 # Test 1: Generation correctness check
 # ============================================================================
+@pytest.mark.skip("https://nvbugspro.nvidia.com/bug/5680911")
 @pytest.mark.parametrize(
     "drafter_type,schedule",
     [
-        ("ngram", {1: 3, 4: 2, 8: 1}),
-        ("model_drafter", {1: 3, 4: 2, 8: 1}),
+        ("ngram", {
+            1: 3,
+            4: 2,
+            8: 1
+        }),
+        ("model_drafter", {
+            1: 3,
+            4: 2,
+            8: 1
+        }),
     ],
 )
 @pytest.mark.high_cuda_memory
 def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     memory_required = 30 if drafter_type == "model_drafter" else 20
     if total_mem_gb < memory_required:
-        pytest.skip(f"Not enough memory (need {memory_required} GB, have {total_mem_gb:.1f} GB)")
+        pytest.skip(
+            f"Not enough memory (need {memory_required} GB, have {total_mem_gb:.1f} GB)"
+        )
 
     models_path = llm_models_root()
     target_model = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
@@ -50,9 +61,9 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
     max_batch_size = 8
     max_draft_len = max(schedule.values())  # Use max from schedule
 
-    kv_cache_config = KvCacheConfig(
-        enable_block_reuse=False, enable_partial_reuse=False, max_tokens=1024
-    )
+    kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                    enable_partial_reuse=False,
+                                    max_tokens=1024)
 
     llm_common_config = dict(
         model=target_model,
@@ -101,13 +112,15 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
             ignore_eos=True,  # Prevent early stopping differences
             top_k=1,
             top_p=1.0,
-        )
-        for i in range(len(prompts))
+        ) for i in range(len(prompts))
     ]
     # With dynamic draft_len_schedule
     llm_with_schedule = LLM(**llm_common_config, speculative_config=spec_config)
-    results_with_schedule = llm_with_schedule.generate(prompts, sampling_params_list)
-    generated_text_with_schedule = [result.outputs[0].text for result in results_with_schedule]
+    results_with_schedule = llm_with_schedule.generate(prompts,
+                                                       sampling_params_list)
+    generated_text_with_schedule = [
+        result.outputs[0].text for result in results_with_schedule
+    ]
     llm_with_schedule.shutdown()
     # Reference: spec decode with fixed max_draft_len (no schedule)
     if drafter_type == "ngram":
@@ -131,12 +144,12 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
     llm_fixed.shutdown()
 
     # Verify correctness: spec decode with schedule should match spec decode without schedule
-    for text_schedule, text_fixed in zip(generated_text_with_schedule, generated_text_fixed):
+    for text_schedule, text_fixed in zip(generated_text_with_schedule,
+                                         generated_text_fixed):
         assert similar(text_schedule, text_fixed), (
             f"{drafter_type} output with draft_len_schedule should match output with fixed draft_len. Got:\n"
             f"With schedule: {text_schedule}\n"
-            f"Fixed: {text_fixed}"
-        )
+            f"Fixed: {text_fixed}")
 
 
 # ============================================================================
@@ -145,12 +158,25 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
 @pytest.mark.parametrize(
     "drafter_type,draft_schedule",
     [
-        ("ngram", {1: 5, 4: 4, 5: 3, 6: 2, 7: 1}),
-        ("model_drafter", {1: 5, 4: 4, 5: 3, 6: 2, 7: 1}),
+        ("ngram", {
+            1: 5,
+            4: 4,
+            5: 3,
+            6: 2,
+            7: 1
+        }),
+        ("model_drafter", {
+            1: 5,
+            4: 4,
+            5: 3,
+            6: 2,
+            7: 1
+        }),
     ],
 )
 @pytest.mark.high_cuda_memory
-def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dict):
+def test_draft_len_schedule_functionality(drafter_type: str,
+                                          draft_schedule: dict):
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
 
@@ -161,9 +187,9 @@ def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dic
         pytest.skip("Not enough memory")
     max_batch_size = 7
 
-    kv_cache_config = KvCacheConfig(
-        enable_block_reuse=False, enable_partial_reuse=False, max_tokens=1024
-    )
+    kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                    enable_partial_reuse=False,
+                                    max_tokens=1024)
 
     llm_common_config = dict(
         model=llm_models_root() / "llama-3.1-model" / "Meta-Llama-3.1-8B",
@@ -184,9 +210,8 @@ def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dic
     else:
         spec_config = DraftTargetDecodingConfig(
             max_draft_len=5,
-            speculative_model_dir=str(
-                llm_models_root() / "llama-3.2-models" / "Llama-3.2-3B-Instruct"
-            ),
+            speculative_model_dir=str(llm_models_root() / "llama-3.2-models" /
+                                      "Llama-3.2-3B-Instruct"),
             draft_len_schedule=draft_schedule,
         )
     prompts = ["The capital of France is" for i in range(7)]
@@ -200,8 +225,7 @@ def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dic
             ignore_eos=True,  # Prevent early stopping
             top_k=1,
             top_p=1.0,
-        )
-        for i in range(7)
+        ) for i in range(7)
     ]
 
     llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
@@ -223,18 +247,19 @@ def mock_should_use_spec_decode(*args, **kwargs):
     drafter.should_use_spec_decode = mock_should_use_spec_decode
 
     # 2. Instrument update_max_total_draft_tokens to capture when draft_len changes
-    def instrumented_update_max_total_draft_tokens(new_max_total_draft_tokens: int):
+    def instrumented_update_max_total_draft_tokens(
+            new_max_total_draft_tokens: int):
         batch_size_active = len(executor.active_requests)
         original_update_max_total_draft_tokens(new_max_total_draft_tokens)
 
-        iteration_data.append(
-            {
-                "batch_size_active": batch_size_active,
-                "drafter_max_draft_tokens": new_max_total_draft_tokens,
-                "use_spec_decode": None,  # Will be filled after _prepare_and_schedule_batch completes
-                "actual_draft_lens": [],  # Will be filled after prepare_draft_tokens
-            }
-        )
+        iteration_data.append({
+            "batch_size_active": batch_size_active,
+            "drafter_max_draft_tokens": new_max_total_draft_tokens,
+            "use_spec_decode":
+            None,  # Will be filled after _prepare_and_schedule_batch completes
+            "actual_draft_lens":
+            [],  # Will be filled after prepare_draft_tokens
+        })
 
     drafter.update_max_total_draft_tokens = instrumented_update_max_total_draft_tokens
 
@@ -247,7 +272,8 @@ def instrumented_prepare_draft(scheduled_batch, resource_manager):
 
         actual_draft_lens = []
         for req in scheduled_batch.generation_requests:
-            draft_len = len(req.py_draft_tokens) if req.py_draft_tokens else 0
+            draft_len = len(
+                req.py_draft_tokens) if req.py_draft_tokens else 0
             actual_draft_lens.append(draft_len)
 
         iteration_data[-1]["actual_draft_lens"] = actual_draft_lens
@@ -315,5 +341,4 @@ def instrumented_prepare_draft(scheduled_batch, resource_manager):
         for req_idx, actual_len in enumerate(actual_lens):
             assert actual_len == drafter_tokens, (
                 f"Iter {idx}, req {req_idx}: ModelDrafter produced {actual_len} "
-                f"!= max_draft_tokens {drafter_tokens}"
-            )
+                f"!= max_draft_tokens {drafter_tokens}")
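
Note on the schedule dicts exercised above: the tests treat `draft_len_schedule` as a map from a batch-size threshold to a draft length, with larger active batches getting shorter speculative drafts (Test 1 also derives `max_draft_len = max(schedule.values())`). Below is a minimal sketch of that lookup, assuming the drafter picks the entry whose threshold is the largest one not exceeding the active batch size; `resolve_draft_len` is a hypothetical helper for illustration, not the library's API.

    import bisect

    def resolve_draft_len(schedule: dict, batch_size: int) -> int:
        # Hypothetical lookup: return the draft length keyed by the largest
        # batch-size threshold that does not exceed the active batch size.
        thresholds = sorted(schedule)
        idx = bisect.bisect_right(thresholds, batch_size) - 1
        return schedule[thresholds[max(idx, 0)]]

    # With the schedule from Test 1, {1: 3, 4: 2, 8: 1}:
    assert resolve_draft_len({1: 3, 4: 2, 8: 1}, 2) == 3  # small batch, longer drafts
    assert resolve_draft_len({1: 3, 4: 2, 8: 1}, 5) == 2
    assert resolve_draft_len({1: 3, 4: 2, 8: 1}, 8) == 1  # full batch, shortest drafts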