@@ -374,5 +374,64 @@ def test_multi_eagle3(use_one_model: bool):
     pass


+@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
+def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
379+ """Test CUDA graph padding with 3 requests and max_batch_size=4.
380+
381+ This test verifies that when using CUDA graph with padding enabled,
382+ the system properly reserves one additional slot for the padded dummy request.
383+ Without this fix, there would be errors caused by no free slot.
384+ """
+    attn_backend = "TRTLLM"
+    enable_block_reuse = False
+    use_one_model = False
+    enable_chunked_prefill = False
+
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
+    # Test with 3 requests and max_batch_size=4 to trigger padding
+    max_batch_size = 4
+    max_draft_len = 4
+    kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                    max_tokens=8192)
+    cuda_graph_config = CudaGraphConfig(batch_sizes=[1, 2, 4],
+                                        enable_padding=True)
+
+    llm_common_config = dict(
+        model=target_model_dir,
+        attn_backend=attn_backend,
+        disable_overlap_scheduler=disable_overlap_scheduler,
+        cuda_graph_config=cuda_graph_config,
+        max_batch_size=max_batch_size,
+        kv_cache_config=kv_cache_config,
+        max_seq_len=8192,
+        enable_chunked_prefill=enable_chunked_prefill,
+    )
+
+    spec_config = EagleDecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model_dir=eagle_model_dir,
+        eagle3_one_model=use_one_model,
+    )
+
+    # Create the LLM instance
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+    prompts = [
+        "The capital of France is", "The president of the United States is",
+        "The future of AI is"
+    ]
+
+    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+    llm_spec.generate(prompts, sampling_params)
+    llm_spec.shutdown()
+
+
 if __name__ == "__main__":
     unittest.main()
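
The padding behavior the new test exercises comes down to a small piece of arithmetic: with CudaGraphConfig(batch_sizes=[1, 2, 4], enable_padding=True), a live batch of 3 requests is rounded up to the nearest captured graph size (4), so one dummy request must occupy a free sequence slot. The sketch below illustrates that rounding under those assumptions; padded_batch_size is a hypothetical helper for illustration only, not part of the TensorRT-LLM API.

# Illustrative sketch only: padded_batch_size is a hypothetical helper,
# not TensorRT-LLM internals. It shows the rounding the test relies on.
import bisect


def padded_batch_size(num_requests: int, captured_sizes: list[int]) -> int:
    """Round a live batch size up to the nearest captured CUDA graph size."""
    sizes = sorted(captured_sizes)
    idx = bisect.bisect_left(sizes, num_requests)
    # Batches larger than every captured size are assumed not to be padded.
    return sizes[idx] if idx < len(sizes) else num_requests


captured = [1, 2, 4]   # mirrors CudaGraphConfig(batch_sizes=[1, 2, 4])
live = 3               # the test submits three prompts
target = padded_batch_size(live, captured)
print(f"pad {live} -> {target}: {target - live} dummy slot(s) needed")
# pad 3 -> 4: 1 dummy slot(s) needed

This is why max_batch_size=4 with only 3 prompts is enough to trigger the bug: the padded dummy request needs the one remaining slot that the fix now reserves.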