@@ -17,35 +17,36 @@
 
 
 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,fp8_target",
     [
-        [True, "TRTLLM", True, False, False, False, True],
-        [True, "TRTLLM", True, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True],
-        [False, "TRTLLM", True, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True],
-        [False, "FLASHINFER", True, False, False, False, True],
-        [False, "TRTLLM", False, True, True, False, True],
-        [True, "TRTLLM", False, True, True, False, True],
-        [True, "TRTLLM", True, False, True, True, True],
-        [True, "TRTLLM", True, False, True, False, True],
+        [True, "TRTLLM", True, False, False, False, True, False],
+        [True, "TRTLLM", True, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False],
+        [False, "TRTLLM", True, False, False, False, False, False],
+        [True, "FLASHINFER", True, False, False, False, True, False],
+        [False, "FLASHINFER", True, False, False, False, True, False],
+        [False, "TRTLLM", False, True, True, False, True, False],
+        [True, "TRTLLM", False, True, True, False, True, False],
+        [True, "TRTLLM", True, False, True, True, True, False],
+        [True, "TRTLLM", True, False, True, False, True, False],
         # TODO: nvbugs/5461761
-        # [True, "TRTLLM", True, False, False, True, True],
-        [True, "TRTLLM", False, False, False, False, True],
-        [False, "TRTLLM", False, False, False, False, True],
-        [True, "TRTLLM", False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True],
-        [True, "TRTLLM", False, False, False, True, False],
+        # [True, "TRTLLM", True, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, False, True, False],
+        [False, "TRTLLM", False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, False, False],
+        [False, "TRTLLM", False, False, False, False, False, False],
+        [True, "TRTLLM", False, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, True, False, False],
         # TODO: nvbugs/5522851
-        # [True, "FLASHINFER", False, False, False, False, True],
-        [False, "FLASHINFER", False, False, False, False, True],
+        # [True, "FLASHINFER", False, False, False, False, True, False],
+        [False, "FLASHINFER", False, False, False, False, True, False],
+        [True, "TRTLLM", True, True, True, True, True, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
-                      use_chain_drafter: bool):
+                      use_chain_drafter: bool, fp8_target: bool):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
@@ -54,13 +55,18 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     models_path = llm_models_root()
     eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+    kv_cache_dtype = 'auto'
+    if fp8_target:
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
+        kv_cache_dtype = 'fp8'
 
     # bs > 1 gives non-deterministic when doing IFB. There are slight chances
     # that ref and spec does not match 100%
     max_batch_size = 1
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
-                                    max_tokens=8192)
+                                    max_tokens=8192,
+                                    dtype=kv_cache_dtype)
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=[1]) if use_cuda_graph else None
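A quick orientation, since the diff only shows the setup half of the test: `fp8_target` swaps in the FP8 target checkpoint and flips the KV cache dtype to `fp8`, while the Eagle3 draft model is left unchanged. The sketch below shows how those values would plausibly flow into the `LLM` construction later in the test. It reuses the names visible in the hunks above; `EagleDecodingConfig`, `LLM`, and the exact keyword set are assumptions about the unshown remainder of the file, not part of this commit.

```python
# Sketch only -- EagleDecodingConfig/LLM and these keyword arguments are
# assumed from the tensorrt_llm llmapi; they are not shown in this diff.
spec_config = EagleDecodingConfig(
    max_draft_len=max_draft_len,            # 4 in the test body above
    speculative_model_dir=eagle_model_dir,  # draft model unchanged by fp8_target
    eagle3_one_model=use_one_model,
)
llm = LLM(
    model=target_model_dir,           # FP8 checkpoint when fp8_target is True
    kv_cache_config=kv_cache_config,  # carries dtype='fp8' on the FP8 path
    speculative_config=spec_config,
    max_batch_size=max_batch_size,
    cuda_graph_config=cuda_graph_config,
    disable_overlap_scheduler=disable_overlap_scheduler,
)
```

Note that only the final parametrize row enables `fp8_target`, and it does so with CUDA graphs, block reuse, one-model mode, chunked prefill, and the chain drafter all switched on (and the overlap scheduler disabled), so the FP8 path gets a single high-coverage configuration rather than a full cross product.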