22import os
33import sys
44import tempfile
5- import unittest
65from pathlib import Path
76from unittest .mock import patch
87
@@ -24,38 +23,40 @@ def enforce_single_worker(monkeypatch):
2423
2524
2625@pytest .mark .parametrize (
27- "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch" ,
26+ "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,fp8_target " ,
2827 [
29- [True , "TRTLLM" , True , False , False , False , True , False ],
30- [True , "TRTLLM" , True , False , False , False , False , False ],
31- [False , "TRTLLM" , True , False , False , False , True , False ],
32- [False , "TRTLLM" , True , False , False , False , False , False ],
33- [True , "FLASHINFER" , True , False , False , False , True , False ],
34- [False , "FLASHINFER" , True , False , False , False , True , False ],
35- [False , "TRTLLM" , False , True , True , False , True , False ],
36- [True , "TRTLLM" , False , True , True , False , True , False ],
37- [True , "TRTLLM" , True , False , True , True , True , False ],
38- [True , "TRTLLM" , True , False , True , False , True , False ],
28+ [True , "TRTLLM" , True , False , False , False , True , False , False ],
29+ [True , "TRTLLM" , True , False , False , False , False , False , False ],
30+ [False , "TRTLLM" , True , False , False , False , True , False , False ],
31+ [False , "TRTLLM" , True , False , False , False , False , False , False ],
32+ [True , "FLASHINFER" , True , False , False , False , True , False , False ],
33+ [False , "FLASHINFER" , True , False , False , False , True , False , False ],
34+ [False , "TRTLLM" , False , True , True , False , True , False , False ],
35+ [True , "TRTLLM" , False , True , True , False , True , False , False ],
36+ [True , "TRTLLM" , True , False , True , True , True , False , False ],
37+ [True , "TRTLLM" , True , False , True , False , True , False , False ],
3938 # TODO: nvbugs/5461761
4039 # [True, "TRTLLM", True, False, False, True, True, False],
41- [True , "TRTLLM" , False , False , False , False , True , False ],
42- [False , "TRTLLM" , False , False , False , False , True , False ],
43- [True , "TRTLLM" , False , False , False , False , False , True ],
44- [False , "TRTLLM" , False , False , False , False , False , True ],
45- [True , "TRTLLM" , False , False , False , False , True , True ],
46- [False , "TRTLLM" , False , False , False , False , True , True ],
47- [True , "TRTLLM" , False , False , False , False , False , False ],
48- [False , "TRTLLM" , False , False , False , False , False , False ],
49- [True , "TRTLLM" , False , False , False , True , True , False ],
50- [True , "TRTLLM" , False , False , False , True , False , False ],
51- [True , "FLASHINFER" , False , False , False , False , True , False ],
52- [False , "FLASHINFER" , False , False , False , False , True , False ],
40+ [True , "TRTLLM" , False , False , False , False , True , False , False ],
41+ [False , "TRTLLM" , False , False , False , False , True , False , False ],
42+ [True , "TRTLLM" , False , False , False , False , False , True , False ],
43+ [False , "TRTLLM" , False , False , False , False , False , True , False ],
44+ [True , "TRTLLM" , False , False , False , False , True , True , False ],
45+ [False , "TRTLLM" , False , False , False , False , True , True , False ],
46+ [True , "TRTLLM" , False , False , False , False , False , False , False ],
47+ [False , "TRTLLM" , False , False , False , False , False , False , False ],
48+ [True , "TRTLLM" , False , False , False , True , True , False , False ],
49+ [True , "TRTLLM" , False , False , False , True , False , False , False ],
50+ [True , "FLASHINFER" , False , False , False , False , True , False , False ],
51+ [False , "FLASHINFER" , False , False , False , False , True , False , False ],
52+ [True , "TRTLLM" , False , True , True , True , True , True , True ],
5353 ])
5454@pytest .mark .high_cuda_memory
5555def test_llama_eagle3 (use_cuda_graph : bool , attn_backend : str ,
5656 disable_overlap_scheduler : bool , enable_block_reuse : bool ,
5757 use_one_model : bool , enable_chunked_prefill : bool ,
58- use_chain_drafter : bool , multi_batch : bool , request ):
58+ use_chain_drafter : bool , multi_batch : bool ,
59+ fp8_target : bool , request ):
5960 # Use enforce_single_worker fixture only when use_chain_drafter is False.
6061 # Otherwise, we can't modify the returned value of _get_allow_chain_drafter in multiprocessing.
6162 if not use_chain_drafter :
@@ -69,6 +70,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
6970 models_path = llm_models_root ()
7071 eagle_model_dir = f"{ models_path } /EAGLE3-LLaMA3.1-Instruct-8B"
7172 target_model_dir = f"{ models_path } /llama-3.1-model/Llama-3.1-8B-Instruct"
73+ if fp8_target :
74+ target_model_dir = f"{ models_path } /llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
7275
7376 # Mock _get_allow_chain_drafter to return False when use_chain_drafter is False
7477 if not use_chain_drafter :
@@ -87,6 +90,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
8790 max_draft_len = 4
8891 kv_cache_config = KvCacheConfig (enable_block_reuse = enable_block_reuse ,
8992 max_tokens = 8192 )
93+ if fp8_target :
94+ kv_cache_config .dtype = 'fp8'
9095 cuda_graph_config = CudaGraphConfig (
9196 batch_sizes = [i for i in range (1 , max_batch_size +
9297 1 )]) if use_cuda_graph else None
@@ -166,9 +171,11 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
166171 generated_text_ref = [result .outputs [0 ].text for result in results_ref ]
167172 llm_ref .shutdown ()
168173
169- for text_spec , text_ref in zip (generated_text_spec , generated_text_ref ):
170- # The spec decode algorithm currently guarantees identical results
171- assert text_spec == text_ref
174+ if not fp8_target :
175+ for text_spec , text_ref in zip (generated_text_spec ,
176+ generated_text_ref ):
177+ # The spec decode algorithm currently guarantees identical results
178+ assert text_spec == text_ref
172179
173180
174181def test_deepseek_eagle3 ():
@@ -374,5 +381,8 @@ def test_multi_eagle3(use_one_model: bool):
374381 pass
375382
376383
377- if __name__ == "__main__" :
378- unittest .main ()
384+ # if __name__ == "__main__":
385+ # # unittest.main()
386+
387+ # # "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,fp8_target",
388+ # # test_llama_eagle3(True, "TRTLLM", False, True, True, True, False, False, False)
0 commit comments