 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch

@@ -24,40 +23,32 @@ def enforce_single_worker(monkeypatch):


 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,fp8_target",
     [
-        [True, "TRTLLM", True, False, False, False, True, False, False],
-        [True, "TRTLLM", True, False, False, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True, False, False],
-        [False, "TRTLLM", True, False, False, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True, False, False],
-        [False, "FLASHINFER", True, False, False, False, True, False, False],
-        [False, "TRTLLM", False, True, True, False, True, False, False],
-        [True, "TRTLLM", False, True, True, False, True, False, False],
-        [True, "TRTLLM", True, False, True, True, True, False, False],
-        [True, "TRTLLM", True, False, True, False, True, False, False],
+        [True, "TRTLLM", True, False, False, False, True, False, False, False],
+        [True, "TRTLLM", True, False, False, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False, False, False],
+        [False, "TRTLLM", True, False, False, False, False, False, False, False],
+        [True, "FLASHINFER", True, False, False, False, True, False, False, False],
+        [False, "FLASHINFER", True, False, False, False, True, False, False, False],
+        [False, "TRTLLM", False, True, True, False, True, False, False, False],
+        [True, "TRTLLM", False, True, True, False, True, False, False, False],
+        [True, "TRTLLM", True, False, True, True, True, False, False, False],
+        [True, "TRTLLM", True, False, True, False, True, False, False, False],
         # TODO: nvbugs/5461761
-        # [True, "TRTLLM", True, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, True, False, False],
-        [False, "TRTLLM", False, False, False, False, True, False, False],
-        [True, "TRTLLM", False, False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, False, True, True],
-        [False, "TRTLLM", False, False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, True, True, False],
-        [False, "TRTLLM", False, False, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True, False, False],
-        [True, "TRTLLM", False, False, False, True, False, False, False],
-        [True, "FLASHINFER", False, False, False, False, True, False, False],
-        [False, "FLASHINFER", False, False, False, False, True, False, False],
+        # [True, "TRTLLM", True, False, False, True, True, False, False, False],
+        [True, "TRTLLM", False, False, False, False, True, False, False, False],
+        [False, "TRTLLM", False, False, False, False, True, False, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, True, False],
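+        # New FP8 case: CUDA graph + block reuse + one-model + chunked prefill
+        # + chain drafter + multi-batch + attention DP on an FP8 target.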
+        [True, "TRTLLM", False, True, True, True, True, True, True, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
                       use_chain_drafter: bool, multi_batch: bool,
-                      attention_dp: bool, request):
+                      attention_dp: bool, fp8_target: bool, request):
     # Use the enforce_single_worker fixture only when use_chain_drafter is False.
     # Otherwise, we can't modify the return value of _get_allow_chain_drafter under multiprocessing.
     if not use_chain_drafter:
@@ -71,6 +62,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     models_path = llm_models_root()
     eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
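+    # With fp8_target, point at the FP8-quantized Llama 3.1 8B checkpoint instead.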
+    if fp8_target:
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"

     # Mock _get_allow_chain_drafter to return False when use_chain_drafter is False
     if not use_chain_drafter:
@@ -89,6 +82,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                     max_tokens=8192)
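+    # Pair the FP8 target with an FP8 KV cache dtype.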
+    if fp8_target:
+        kv_cache_config.dtype = 'fp8'
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=[i for i in range(1, max_batch_size +
                                       1)]) if use_cuda_graph else None
@@ -169,9 +164,11 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     generated_text_ref = [result.outputs[0].text for result in results_ref]
     llm_ref.shutdown()

-    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
-        # The spec decode algorithm currently guarantees identical results
-        assert text_spec == text_ref
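+    # FP8 quantization can perturb logits enough to change sampled tokens,
+    # so only enforce exact-match outputs for the non-FP8 target.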
+    if not fp8_target:
+        for text_spec, text_ref in zip(generated_text_spec,
+                                       generated_text_ref):
+            # The spec decode algorithm currently guarantees identical results
+            assert text_spec == text_ref


 def test_deepseek_eagle3():
@@ -377,6 +374,7 @@ def test_multi_eagle3(use_one_model: bool):
     pass


 @pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
 def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
     """Test CUDA graph padding with 3 requests and max_batch_size=4.