
Commit 6a2ecbc

fix eagle3 fp8 chunk
Signed-off-by: Dylan Chen <[email protected]>
1 parent b0cb9ca commit 6a2ecbc

File tree

4 files changed, +69 -31 lines changed


cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

Lines changed: 10 additions & 0 deletions

@@ -243,6 +243,16 @@ bool XqaDispatcher::shouldUse(XQAParams const& params)
 
         return true;
     }
+
+    if (params.kv_cache_data_type == DATA_TYPE_E4M3
+        && (params.data_type == DATA_TYPE_BF16 || params.data_type == DATA_TYPE_FP16))
+    {
+        TLLM_LOG_DEBUG(
+            "XQA kernels are selected in the generation phase for fp16/bf16 input and e4m3 kv cache because MMHA does "
+            "not support this combination.");
+        return true;
+    }
+
     return mDecoderXqaRunner->shouldUse(params, /*forConfigurePlugin=*/false);
 }
 
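
The guard above makes XqaDispatcher::shouldUse pick the XQA path whenever the KV cache is stored as e4m3 (FP8) while the attention inputs are fp16/bf16, because MMHA cannot handle that combination in the generation phase. Below is a minimal Python sketch of the same decision rule; the class and constant names are illustrative stand-ins for the C++ XQAParams fields, not the real API.

```python
# Sketch only: mirrors the new dispatch guard, not the C++ implementation.
from dataclasses import dataclass

DATA_TYPE_FP16 = "fp16"
DATA_TYPE_BF16 = "bf16"
DATA_TYPE_E4M3 = "e4m3"


@dataclass
class XqaParamsSketch:
    data_type: str           # activation dtype of the attention inputs
    kv_cache_data_type: str  # storage dtype of the KV cache


def should_use_xqa(params: XqaParamsSketch, runner_decision: bool) -> bool:
    # MMHA does not support fp16/bf16 inputs with an e4m3 KV cache,
    # so XQA must be selected for that combination.
    if (params.kv_cache_data_type == DATA_TYPE_E4M3
            and params.data_type in (DATA_TYPE_FP16, DATA_TYPE_BF16)):
        return True
    # Otherwise fall back to the existing runner heuristic.
    return runner_decision


assert should_use_xqa(XqaParamsSketch(DATA_TYPE_BF16, DATA_TYPE_E4M3), False)
```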

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 5 additions & 0 deletions

@@ -1211,6 +1211,7 @@ def forward(
         output_sf: Optional[torch.Tensor] = None,
         attention_sinks: Optional[torch.Tensor] = None,
         chunked_prefill_buffer_batch_size: int = 1,
+        fp8_fmha_for_eagle3: bool = False,
         **kwargs,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
         assert isinstance(
@@ -1293,6 +1294,10 @@ def forward(
         if use_nvfp4_output:
             # Use UINT8 as the container dtype for NVFP4.
             out_dtype = torch.uint8
+        # elif fp8_fmha_for_eagle3:
+        elif self.has_fp8_kv_cache and not self.has_fp8_qdq and out_scale is not None:
+            # Force to use FP8 FMHA for (eagle3 + FP8 target model + BF16/FP16 draft model) in draft layers
+            out_dtype = torch.float8_e4m3fn
         elif (self.has_fp8_qdq or self.has_nvfp4 or self.has_fp8_block_wise
               or self.has_fp8_rowwise
               or self.has_w4a8_nvfp4_fp8) and (self.has_fp8_kv_cache
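
This hunk adds an fp8_fmha_for_eagle3 flag to forward and a new branch in the output-dtype selection: when the layer has an FP8 KV cache but no FP8 QDQ quantization and an out_scale is supplied, the output dtype is forced to torch.float8_e4m3fn so the FP8 FMHA kernel can serve the BF16/FP16 eagle3 draft layers. The following is a standalone sketch of that selection logic; the function and parameter names are illustrative, not the attention backend's API.

```python
# Illustrative sketch of the out_dtype selection with the new FP8-KV-cache branch.
from typing import Optional

import torch


def pick_out_dtype(use_nvfp4_output: bool, has_fp8_kv_cache: bool,
                   has_fp8_qdq: bool,
                   out_scale: Optional[torch.Tensor]) -> Optional[torch.dtype]:
    if use_nvfp4_output:
        # UINT8 is the container dtype for NVFP4 output.
        return torch.uint8
    if has_fp8_kv_cache and not has_fp8_qdq and out_scale is not None:
        # New branch: a BF16/FP16 layer running on an FP8 KV cache (eagle3
        # draft layers under an FP8 target model) is routed to FP8 FMHA.
        return torch.float8_e4m3fn
    # The remaining quantized cases from the original code are omitted here;
    # None means "keep the input dtype".
    return None


print(pick_out_dtype(False, True, False, torch.tensor([1.0])))  # torch.float8_e4m3fn
```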

tensorrt_llm/_torch/modules/attention.py

Lines changed: 14 additions & 1 deletion

@@ -404,6 +404,14 @@ def _attn_impl(
         if mrope_position_deltas is not None:
             mrope_config["mrope_position_deltas"] = mrope_position_deltas
 
+        # Force FP8 FMHA for a BF16/FP16 model with FP8 KV cache (e.g. eagle3 + FP8 target model + BF16/FP16 draft model)
+        forced_to_fp8_fmha = not self.has_quant_scale and self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_kv_cache(
+        ) and attn_metadata.num_contexts != 0
+        if forced_to_fp8_fmha:
+            out_scale = torch.tensor([1.0],
+                                     dtype=torch.float32,
+                                     device=q.device)
+
         attn_output = self.attn.forward(
             q,
             k,
@@ -425,7 +433,12 @@ def _attn_impl(
             assert len(
                 attn_output
             ) == 2, "attn_output should be a tuple of (output, output_sf)"
-            return attn_output[0], attn_output[1]
+            if forced_to_fp8_fmha:
+                return attn_output[0].to(q.dtype), attn_output[1]
+            else:
+                return attn_output[0], attn_output[1]
+        if forced_to_fp8_fmha:
+            return attn_output.to(q.dtype), None
         return attn_output, None
 
     def forward_impl(
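
At the module level, _attn_impl now detects the forced-FP8-FMHA case (no per-layer quant scales, an FP8 KV cache in the quant config, and at least one context request), passes a unit out_scale so the backend emits FP8 output, and casts that output back to the activation dtype before returning. Below is a self-contained sketch of that round trip; the attention kernel call is mocked, and any names not in the diff are illustrative.

```python
# Sketch of the forced-FP8-FMHA round trip; the real code calls
# self.attn.forward(..., out_scale=out_scale) instead of the mock below.
import torch


def attn_with_forced_fp8(q: torch.Tensor, forced_to_fp8_fmha: bool):
    out_scale = None
    if forced_to_fp8_fmha:
        # A unit scale tells the backend to produce e4m3 (FP8) output.
        out_scale = torch.tensor([1.0], dtype=torch.float32, device=q.device)

    # Mocked kernel call: pretend FP8 output when the forced path is active.
    attn_output = q.to(torch.float8_e4m3fn) if out_scale is not None else q

    if forced_to_fp8_fmha:
        # Cast back so downstream BF16/FP16 layers see the dtype they expect.
        return attn_output.to(q.dtype), None
    return attn_output, None


out, _ = attn_with_forced_fp8(torch.randn(2, 8, dtype=torch.bfloat16), True)
assert out.dtype == torch.bfloat16
```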

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 40 additions & 30 deletions

@@ -2,7 +2,6 @@
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch
 
@@ -24,38 +23,40 @@ def enforce_single_worker(monkeypatch):
 
 
 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,fp8_target",
     [
-        [True, "TRTLLM", True, False, False, False, True, False],
-        [True, "TRTLLM", True, False, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True, False],
-        [False, "TRTLLM", True, False, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True, False],
-        [False, "FLASHINFER", True, False, False, False, True, False],
-        [False, "TRTLLM", False, True, True, False, True, False],
-        [True, "TRTLLM", False, True, True, False, True, False],
-        [True, "TRTLLM", True, False, True, True, True, False],
-        [True, "TRTLLM", True, False, True, False, True, False],
+        [True, "TRTLLM", True, False, False, False, True, False, False],
+        [True, "TRTLLM", True, False, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False, False],
+        [False, "TRTLLM", True, False, False, False, False, False, False],
+        [True, "FLASHINFER", True, False, False, False, True, False, False],
+        [False, "FLASHINFER", True, False, False, False, True, False, False],
+        [False, "TRTLLM", False, True, True, False, True, False, False],
+        [True, "TRTLLM", False, True, True, False, True, False, False],
+        [True, "TRTLLM", True, False, True, True, True, False, False],
+        [True, "TRTLLM", True, False, True, False, True, False, False],
         # TODO: nvbugs/5461761
         # [True, "TRTLLM", True, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, True, False],
-        [False, "TRTLLM", False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, False, True],
-        [False, "TRTLLM", False, False, False, False, False, True],
-        [True, "TRTLLM", False, False, False, False, True, True],
-        [False, "TRTLLM", False, False, False, False, True, True],
-        [True, "TRTLLM", False, False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, True, False, False],
-        [True, "FLASHINFER", False, False, False, False, True, False],
-        [False, "FLASHINFER", False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, True, False, False],
+        [False, "TRTLLM", False, False, False, False, True, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, False],
+        [False, "TRTLLM", False, False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, True, True, False],
+        [False, "TRTLLM", False, False, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, False, False, False, False],
+        [False, "TRTLLM", False, False, False, False, False, False, False],
+        [True, "TRTLLM", False, False, False, True, True, False, False],
+        [True, "TRTLLM", False, False, False, True, False, False, False],
+        [True, "FLASHINFER", False, False, False, False, True, False, False],
+        [False, "FLASHINFER", False, False, False, False, True, False, False],
+        [True, "TRTLLM", False, True, True, True, True, True, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
-                      use_chain_drafter: bool, multi_batch: bool, request):
+                      use_chain_drafter: bool, multi_batch: bool,
+                      fp8_target: bool, request):
     # Use enforce_single_worker fixture only when use_chain_drafter is False.
     # Otherwise, we can't modify the returned value of _get_allow_chain_drafter in multiprocessing.
     if not use_chain_drafter:
@@ -69,6 +70,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     models_path = llm_models_root()
     eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+    if fp8_target:
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
 
     # Mock _get_allow_chain_drafter to return False when use_chain_drafter is False
     if not use_chain_drafter:
@@ -87,6 +90,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                     max_tokens=8192)
+    if fp8_target:
+        kv_cache_config.dtype = 'fp8'
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=[i for i in range(1, max_batch_size +
                                       1)]) if use_cuda_graph else None
@@ -166,9 +171,11 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     generated_text_ref = [result.outputs[0].text for result in results_ref]
     llm_ref.shutdown()
 
-    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
-        # The spec decode algorithm currently guarantees identical results
-        assert text_spec == text_ref
+    if not fp8_target:
+        for text_spec, text_ref in zip(generated_text_spec,
+                                       generated_text_ref):
+            # The spec decode algorithm currently guarantees identical results
+            assert text_spec == text_ref
 
 
 def test_deepseek_eagle3():
@@ -374,5 +381,8 @@ def test_multi_eagle3(use_one_model: bool):
     pass
 
 
-if __name__ == "__main__":
-    unittest.main()
+# if __name__ == "__main__":
+#     # unittest.main()
+
+#     # "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch",
+#     # test_llama_eagle3(True, "TRTLLM", False, True, True, True, False, False)
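
The test gains an fp8_target parameter: it points the target model at the FP8 Llama-3.1-8B-Instruct checkpoint, stores the KV cache in FP8, and skips the exact-match comparison against the reference run, since FP8 execution no longer guarantees bit-identical outputs. Below is a hedged sketch of the FP8-specific configuration taken from the diff; the model path is a placeholder, and the KvCacheConfig import location is assumed to match what the test already uses.

```python
# Sketch of the FP8-target setup added by the test; paths are placeholders.
from tensorrt_llm.llmapi import KvCacheConfig

models_path = "/path/to/models"  # placeholder for llm_models_root()
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"

kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
# Store the KV cache in FP8 to match the FP8 target model.
kv_cache_config.dtype = 'fp8'

# With an FP8 target, spec-decode outputs are not guaranteed to be
# bit-identical to the reference run, so the exact string comparison
# is skipped when fp8_target is True.
```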
