
Commit 2c063a2

fix eagle3 fp8 chunk

Signed-off-by: Dylan Chen <[email protected]>
Parent: d3059db

4 files changed: 57 additions, 31 deletions

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

Lines changed: 10 additions & 0 deletions
@@ -243,6 +243,16 @@ bool XqaDispatcher::shouldUse(XQAParams const& params)
 
         return true;
     }
+
+    if (params.kv_cache_data_type == DATA_TYPE_E4M3
+        && (params.data_type == DATA_TYPE_BF16 || params.data_type == DATA_TYPE_FP16))
+    {
+        TLLM_LOG_DEBUG(
+            "XQA kernels are selected in the generation phase for fp16/bf16 input and e4m3 kv cache because MMHA does "
+            "not support this combination.");
+        return true;
+    }
+
     return mDecoderXqaRunner->shouldUse(params, /*forConfigurePlugin=*/false);
 }
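The added guard routes fp16/bf16 activations paired with an e4m3 (FP8) KV cache to the XQA kernels in the generation phase, since MMHA does not handle that combination. Below is a minimal Python sketch of the same selection rule; the DataType enum and XQAParams dataclass are hypothetical stand-ins for the C++ types.

from dataclasses import dataclass
from enum import Enum, auto


class DataType(Enum):  # hypothetical stand-in for the C++ data-type enum
    FP16 = auto()
    BF16 = auto()
    E4M3 = auto()  # FP8 KV cache


@dataclass
class XQAParams:  # hypothetical subset of the real XQAParams struct
    data_type: DataType
    kv_cache_data_type: DataType


def should_use_xqa(params: XQAParams) -> bool:
    # fp16/bf16 activations with an e4m3 KV cache must use XQA in the
    # generation phase, because MMHA does not support this combination.
    if (params.kv_cache_data_type is DataType.E4M3
            and params.data_type in (DataType.FP16, DataType.BF16)):
        return True
    # Otherwise fall back to the regular dispatch heuristics (not modeled here).
    return False


print(should_use_xqa(XQAParams(DataType.BF16, DataType.E4M3)))  # True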

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 5 additions & 0 deletions
@@ -1211,6 +1211,7 @@ def forward(
         output_sf: Optional[torch.Tensor] = None,
         attention_sinks: Optional[torch.Tensor] = None,
         chunked_prefill_buffer_batch_size: int = 1,
+        fp8_fmha_for_eagle3: bool = False,
         **kwargs,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
         assert isinstance(
@@ -1293,6 +1294,10 @@ def forward(
         if use_nvfp4_output:
             # Use UINT8 as the container dtype for NVFP4.
             out_dtype = torch.uint8
+        # elif fp8_fmha_for_eagle3:
+        elif self.has_fp8_kv_cache and not self.has_fp8_qdq and out_scale is not None:
+            # Force to use FP8 FMHA for (eagle3 + FP8 target model + BF16/FP16 draft model) in draft layers
+            out_dtype = torch.float8_e4m3fn
         elif (self.has_fp8_qdq or self.has_nvfp4 or self.has_fp8_block_wise
               or self.has_fp8_rowwise
               or self.has_w4a8_nvfp4_fp8) and (self.has_fp8_kv_cache
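The new branch slots in between the NVFP4 output path and the existing FP8/NVFP4 quantization checks: when the layer has an FP8 KV cache, no FP8 QDQ scales, and an out_scale is supplied, the FMHA output dtype is forced to e4m3. A simplified sketch of that selection order follows; the function name and default dtype are illustrative, not the backend's API.

import torch


def pick_out_dtype(use_nvfp4_output: bool, has_fp8_kv_cache: bool,
                   has_fp8_qdq: bool, out_scale,
                   default_dtype=torch.bfloat16):
    # Hypothetical, simplified version of the out_dtype selection in forward().
    if use_nvfp4_output:
        # UINT8 is the container dtype for NVFP4 output.
        return torch.uint8
    if has_fp8_kv_cache and not has_fp8_qdq and out_scale is not None:
        # BF16/FP16 eagle3 draft layers on top of an FP8 target: force FP8 FMHA.
        return torch.float8_e4m3fn
    return default_dtype


print(pick_out_dtype(False, True, False, torch.tensor([1.0])))  # torch.float8_e4m3fn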

tensorrt_llm/_torch/modules/attention.py

Lines changed: 14 additions & 1 deletion
@@ -404,6 +404,14 @@ def _attn_impl(
         if mrope_position_deltas is not None:
             mrope_config["mrope_position_deltas"] = mrope_position_deltas
 
+        # Be forced to use FP8 FMHA for BF16/FP16 model with FP8 KV cache (e.g. eagle3 + FP8 target model + BF16/FP16 draft model)
+        forced_to_fp8_fmha = not self.has_quant_scale and self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_kv_cache(
+        ) and attn_metadata.num_contexts != 0
+        if forced_to_fp8_fmha:
+            out_scale = torch.tensor([1.0],
+                                     dtype=torch.float32,
+                                     device=q.device)
+
         attn_output = self.attn.forward(
             q,
             k,
@@ -425,7 +433,12 @@ def _attn_impl(
             assert len(
                 attn_output
             ) == 2, "attn_output should be a tuple of (output, output_sf)"
-            return attn_output[0], attn_output[1]
+            if forced_to_fp8_fmha:
+                return attn_output[0].to(q.dtype), attn_output[1]
+            else:
+                return attn_output[0], attn_output[1]
+        if forced_to_fp8_fmha:
+            return attn_output.to(q.dtype), None
         return attn_output, None
 
     def forward_impl(
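On the module side, forced_to_fp8_fmha is derived from the layer's quant config (FP8 KV cache but no activation quant scales), a unit out_scale is passed so the kernel can write e4m3, and the e4m3 output is cast back to the query dtype before it leaves _attn_impl. A small, self-contained sketch of that cast-back step; the helper name and shapes are hypothetical.

import torch


def cast_back_if_forced(attn_output, q_dtype, forced_to_fp8_fmha: bool):
    # Hypothetical helper mirroring the return path of _attn_impl: when FP8
    # FMHA was forced, the e4m3 kernel output is cast back to the original
    # activation dtype before it is returned.
    if isinstance(attn_output, tuple):
        out, out_sf = attn_output
        return (out.to(q_dtype), out_sf) if forced_to_fp8_fmha else (out, out_sf)
    return (attn_output.to(q_dtype), None) if forced_to_fp8_fmha else (attn_output, None)


q = torch.randn(4, 8, dtype=torch.bfloat16)
fp8_out = torch.randn(4, 8).to(torch.float8_e4m3fn)  # stand-in for the kernel output
out, _ = cast_back_if_forced(fp8_out, q.dtype, forced_to_fp8_fmha=True)
assert out.dtype == torch.bfloat16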

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 28 additions & 30 deletions
@@ -2,7 +2,6 @@
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch
 
@@ -24,40 +23,32 @@ def enforce_single_worker(monkeypatch):
 
 
 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,fp8_target",
     [
-        [True, "TRTLLM", True, False, False, False, True, False, False],
-        [True, "TRTLLM", True, False, False, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True, False, False],
-        [False, "TRTLLM", True, False, False, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True, False, False],
-        [False, "FLASHINFER", True, False, False, False, True, False, False],
-        [False, "TRTLLM", False, True, True, False, True, False, False],
-        [True, "TRTLLM", False, True, True, False, True, False, False],
-        [True, "TRTLLM", True, False, True, True, True, False, False],
-        [True, "TRTLLM", True, False, True, False, True, False, False],
+        [True, "TRTLLM", True, False, False, False, True, False, False, False],
+        [True, "TRTLLM", True, False, False, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False, False, False],
+        [False, "TRTLLM", True, False, False, False, False, False, False, False],
+        [True, "FLASHINFER", True, False, False, False, True, False, False, False],
+        [False, "FLASHINFER", True, False, False, False, True, False, False, False],
+        [False, "TRTLLM", False, True, True, False, True, False, False, False],
+        [True, "TRTLLM", False, True, True, False, True, False, False, False],
+        [True, "TRTLLM", True, False, True, True, True, False, False, False],
+        [True, "TRTLLM", True, False, True, False, True, False, False, False],
         # TODO: nvbugs/5461761
-        # [True, "TRTLLM", True, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, True, False, False],
-        [False, "TRTLLM", False, False, False, False, True, False, False],
-        [True, "TRTLLM", False, False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, False, True, True],
-        [False, "TRTLLM", False, False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, True, True, False],
-        [False, "TRTLLM", False, False, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True, False, False],
-        [True, "TRTLLM", False, False, False, True, False, False, False],
-        [True, "FLASHINFER", False, False, False, False, True, False, False],
-        [False, "FLASHINFER", False, False, False, False, True, False, False],
+        # [True, "TRTLLM", True, False, False, True, True, False, False, False],
+        [True, "TRTLLM", False, False, False, False, True, False, False, False],
+        [False, "TRTLLM", False, False, False, False, True, False, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, True, False],
+        [True, "TRTLLM", False, True, True, True, True, True, True, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
                       use_chain_drafter: bool, multi_batch: bool,
-                      attention_dp: bool, request):
+                      attention_dp: bool, fp8_target: bool, request):
     # Use enforce_single_worker fixture only when use_chain_drafter is False.
     # Otherwise, we can't modify the returned value of _get_allow_chain_drafter in multiprocessing.
     if not use_chain_drafter:
@@ -71,6 +62,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     models_path = llm_models_root()
     eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+    if fp8_target:
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
 
     # Mock _get_allow_chain_drafter to return False when use_chain_drafter is False
     if not use_chain_drafter:
@@ -89,6 +82,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                     max_tokens=8192)
+    if fp8_target:
+        kv_cache_config.dtype = 'fp8'
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=[i for i in range(1, max_batch_size +
                                       1)]) if use_cuda_graph else None
@@ -169,9 +164,11 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     generated_text_ref = [result.outputs[0].text for result in results_ref]
     llm_ref.shutdown()
 
-    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
-        # The spec decode algorithm currently guarantees identical results
-        assert text_spec == text_ref
+    if not fp8_target:
+        for text_spec, text_ref in zip(generated_text_spec,
+                                       generated_text_ref):
+            # The spec decode algorithm currently guarantees identical results
+            assert text_spec == text_ref
 
 
 def test_deepseek_eagle3():
@@ -377,6 +374,7 @@ def test_multi_eagle3(use_one_model: bool):
         pass
 
 
+<<<<<<< HEAD
 @pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
 def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
     """Test CUDA graph padding with 3 requests and max_batch_size=4.
