@@ -17,34 +17,35 @@
 
 
 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,fp8_target",
     [
-        [True, "TRTLLM", True, False, False, False, True],
-        [True, "TRTLLM", True, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True],
-        [False, "TRTLLM", True, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True],
-        [False, "FLASHINFER", True, False, False, False, True],
-        [False, "TRTLLM", False, True, True, False, True],
-        [True, "TRTLLM", False, True, True, False, True],
-        [True, "TRTLLM", True, False, True, True, True],
-        [True, "TRTLLM", True, False, True, False, True],
+        [True, "TRTLLM", True, False, False, False, True, False],
+        [True, "TRTLLM", True, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False],
+        [False, "TRTLLM", True, False, False, False, False, False],
+        [True, "FLASHINFER", True, False, False, False, True, False],
+        [False, "FLASHINFER", True, False, False, False, True, False],
+        [False, "TRTLLM", False, True, True, False, True, False],
+        [True, "TRTLLM", False, True, True, False, True, False],
+        [True, "TRTLLM", True, False, True, True, True, False],
+        [True, "TRTLLM", True, False, True, False, True, False],
         # TODO: nvbugs/5461761
-        # [True, "TRTLLM", True, False, False, True, True],
-        [True, "TRTLLM", False, False, False, False, True],
-        [False, "TRTLLM", False, False, False, False, True],
-        [True, "TRTLLM", False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True],
-        [True, "TRTLLM", False, False, False, True, False],
-        [True, "FLASHINFER", False, False, False, False, True],
-        [False, "FLASHINFER", False, False, False, False, True],
+        # [True, "TRTLLM", True, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, False, True, False],
+        [False, "TRTLLM", False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, False, False],
+        [False, "TRTLLM", False, False, False, False, False, False],
+        [True, "TRTLLM", False, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, True, False, False],
+        [True, "FLASHINFER", False, False, False, False, True, False],
+        [False, "FLASHINFER", False, False, False, False, True, False],
+        [True, "TRTLLM", True, True, True, True, True, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
-                      use_chain_drafter: bool):
+                      use_chain_drafter: bool, fp8_target: bool):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
@@ -53,13 +54,18 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     models_path = llm_models_root()
     eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+    kv_cache_dtype = 'auto'
+    if fp8_target:
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
+        kv_cache_dtype = 'fp8'
 
     # bs > 1 gives non-deterministic when doing IFB. There are slight chances
     # that ref and spec does not match 100%
     max_batch_size = 1
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
-                                    max_tokens=8192)
+                                    max_tokens=8192,
+                                    dtype=kv_cache_dtype)
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=[1]) if use_cuda_graph else None
 