@@ -124,6 +124,14 @@ bool DecoderXQAImplJIT::shouldUse(XQAParams const& umbrellaXQAParams, bool forCo
bool hasPerfGain = mayHavePerfGain(xqaParams);
if (!hasPerfGain)
{
if (xqaParams.kv_cache_data_type == DATA_TYPE_E4M3
&& (xqaParams.data_type == DATA_TYPE_BF16 || xqaParams.data_type == DATA_TYPE_FP16))
{
TLLM_LOG_DEBUG(
"JIT XQA is selected in the generation phase for fp16/bf16 input and e4m3 kv cache because MMHA "
"does not support this combination.");
return true;
}
TLLM_LOG_DEBUG("JIT XQA is not used: maybe no performance gain");
return false;
}
5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -1211,6 +1211,7 @@ def forward(
output_sf: Optional[torch.Tensor] = None,
attention_sinks: Optional[torch.Tensor] = None,
chunked_prefill_buffer_batch_size: int = 1,
fp8_fmha_for_eagle3: bool = False,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
assert isinstance(
@@ -1293,6 +1294,10 @@ def forward(
if use_nvfp4_output:
# Use UINT8 as the container dtype for NVFP4.
out_dtype = torch.uint8
# elif fp8_fmha_for_eagle3:
elif self.has_fp8_kv_cache and not self.has_fp8_qdq and out_scale is not None:
# Force to use FP8 FMHA for (eagle3 + FP8 target model + BF16/FP16 draft model) in draft layers
out_dtype = torch.float8_e4m3fn
Collaborator: That said, this is not true for all cases. On Blackwell, the FP8 FMHA kernels can output BF16 directly, in which case we want to avoid explicitly doing the conversion after the attention op.

We'd better add a flag or something (it is not clear to me yet exactly what), false by default, so that it won't break other workflows.
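To make that suggestion concrete, here is a minimal sketch (an assumption, not this PR's code) of how a default-False flag could gate the FP8-output path in the backend's out_dtype selection. `select_out_dtype` and `force_fp8_fmha_output` are hypothetical names; the `has_fp8_kv_cache` / `has_fp8_qdq` / `out_scale` conditions follow the diff above.

```python
import torch
from typing import Optional


def select_out_dtype(use_nvfp4_output: bool,
                     has_fp8_kv_cache: bool,
                     has_fp8_qdq: bool,
                     out_scale: Optional[torch.Tensor],
                     force_fp8_fmha_output: bool = False) -> Optional[torch.dtype]:
    """Pick the FMHA output container dtype; the new flag defaults to False."""
    if use_nvfp4_output:
        # Use UINT8 as the container dtype for NVFP4 (as in the diff above).
        return torch.uint8
    if (force_fp8_fmha_output and has_fp8_kv_cache and not has_fp8_qdq
            and out_scale is not None):
        # Opt-in path only: request FP8 output from the FMHA kernel, e.g. for
        # eagle3 draft layers with an FP8 KV cache on pre-Blackwell GPUs.
        return torch.float8_e4m3fn
    # None means "fall through to the existing quantization-mode checks".
    return None
```

With the flag left at its default, every existing workflow takes the final branch and is unaffected.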

elif (self.has_fp8_qdq or self.has_nvfp4 or self.has_fp8_block_wise
or self.has_fp8_rowwise
or self.has_w4a8_nvfp4_fp8) and (self.has_fp8_kv_cache
15 changes: 14 additions & 1 deletion tensorrt_llm/_torch/modules/attention.py
@@ -404,6 +404,14 @@ def _attn_impl(
if mrope_position_deltas is not None:
mrope_config["mrope_position_deltas"] = mrope_position_deltas

# Force FP8 FMHA for a BF16/FP16 model with an FP8 KV cache (e.g. eagle3 + FP8 target model + BF16/FP16 draft model)
Collaborator: This seems too specific (more like a WAR). @yuxianq, do you have any insights on this? Thanks.

Collaborator: I agree that it is too specific. The purpose of this PR is to add a way to explicitly control, from outside the attention op, whether we use FP8 FMHA. How about adding a force_fp8_fmha flag to attention (false by default) and enabling it only in the eagle3 case? We don't need to add new fields to the common AttentionOp.
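A rough sketch of that suggestion, assuming a module-level `force_fp8_fmha` argument that defaults to False and is enabled only by the eagle3 draft path. `attn_forward` and `_backend_attention` are illustrative stand-ins, not the actual TensorRT-LLM API.

```python
import torch


def _backend_attention(q, k, v, out_scale=None):
    # Stand-in for the TRTLLM attention backend so the sketch is self-contained;
    # a non-None out_scale here means "the kernel writes FP8 output".
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.to(torch.float8_e4m3fn) if out_scale is not None else out


def attn_forward(q, k, v, *, force_fp8_fmha: bool = False):
    """force_fp8_fmha defaults to False; only the eagle3 draft path would set it."""
    out_scale = None
    if force_fp8_fmha:
        # A unit scale asks the kernel for FP8 output (mirrors the diff above).
        out_scale = torch.tensor([1.0], dtype=torch.float32, device=q.device)
    attn_output = _backend_attention(q, k, v, out_scale=out_scale)
    if force_fp8_fmha:
        # Convert back to the activation dtype before handing the result to the
        # rest of the (BF16/FP16) draft model.
        attn_output = attn_output.to(q.dtype)
    return attn_output
```

Only the eagle3 draft layers would call `attn_forward(q, k, v, force_fp8_fmha=True)`; every other caller keeps the default and sees no behavior change.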

Collaborator: That makes sense to me. Thanks!

forced_to_fp8_fmha = not self.has_quant_scale and self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_kv_cache(
Collaborator: Same as above. We can add the conversion kernel inside the attention op (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/common/attentionOp.cpp), so that if the output dtype is not supported on Hopper/Ampere (which use fmha_v2), we can invoke the conversion kernel there. Exposing this logic outside the attention op will complicate the design, as it is only needed by fmha_v2.

Collaborator (author): Hi @PerkzZheng, I have moved the logic to attentionOp and distinguished the behaviors of Blackwell and pre-Blackwell. The CI failure has been fixed locally. Could you please review it again? Thanks.
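To spell out the dtype reasoning in this exchange, here is a small hedged sketch (not the attentionOp implementation) of the decision it describes: Blackwell FP8 FMHA kernels can write BF16/FP16 directly, while pre-Blackwell fmha_v2 kernels write FP8 and the op converts afterwards. `fp8_fmha_output_plan` and its SM-version threshold are assumptions used purely for illustration.

```python
import torch


def fp8_fmha_output_plan(sm_version: int, requested_dtype: torch.dtype):
    """Return (kernel_output_dtype, convert_inside_op).

    Hypothetical helper mirroring the proposal above: on Blackwell the kernel
    emits the requested dtype directly; on Hopper/Ampere the kernel emits FP8
    and the op converts internally, so callers never see FP8.
    """
    if sm_version >= 100:
        # Blackwell: the kernel writes the requested dtype directly.
        return requested_dtype, False
    # Hopper/Ampere (fmha_v2): kernel writes FP8; the op converts afterwards.
    return torch.float8_e4m3fn, True


# On SM90 the op itself would convert FP8 -> BF16 after the kernel runs.
print(fp8_fmha_output_plan(90, torch.bfloat16))   # (torch.float8_e4m3fn, True)
print(fp8_fmha_output_plan(100, torch.bfloat16))  # (torch.bfloat16, False)
```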

) and attn_metadata.num_contexts != 0
if forced_to_fp8_fmha:
out_scale = torch.tensor([1.0],
dtype=torch.float32,
device=q.device)

attn_output = self.attn.forward(
q,
k,
@@ -425,7 +433,12 @@
assert len(
attn_output
) == 2, "attn_output should be a tuple of (output, output_sf)"
return attn_output[0], attn_output[1]
if forced_to_fp8_fmha:
return attn_output[0].to(q.dtype), attn_output[1]
else:
return attn_output[0], attn_output[1]
if forced_to_fp8_fmha:
return attn_output.to(q.dtype), None
return attn_output, None

def forward_impl(
91 changes: 62 additions & 29 deletions tests/unittest/_torch/speculative/test_eagle3.py
@@ -23,40 +23,68 @@ def enforce_single_worker(monkeypatch):


@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,fp8_target",
[
[True, "TRTLLM", True, False, False, False, True, False, False],
[True, "TRTLLM", True, False, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False, False],
[False, "TRTLLM", True, False, False, False, False, False, False],
[True, "FLASHINFER", True, False, False, False, True, False, False],
[False, "FLASHINFER", True, False, False, False, True, False, False],
[False, "TRTLLM", False, True, True, False, True, False, False],
[True, "TRTLLM", False, True, True, False, True, False, False],
[True, "TRTLLM", True, False, True, True, True, False, False],
[True, "TRTLLM", True, False, True, False, True, False, False],
[True, "TRTLLM", True, False, False, False, True, False, False, False],
[True, "TRTLLM", True, False, False, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False, False, False],
[
False, "TRTLLM", True, False, False, False, False, False, False,
False
],
[
True, "FLASHINFER", True, False, False, False, True, False, False,
False
],
[
False, "FLASHINFER", True, False, False, False, True, False, False,
False
],
[False, "TRTLLM", False, True, True, False, True, False, False, False],
[True, "TRTLLM", False, True, True, False, True, False, False, False],
[True, "TRTLLM", True, False, True, True, True, False, False, False],
[True, "TRTLLM", True, False, True, False, True, False, False, False],
# TODO: nvbugs/5461761
# [True, "TRTLLM", True, False, False, True, True, False],
[True, "TRTLLM", False, False, False, False, True, False, False],
[False, "TRTLLM", False, False, False, False, True, False, False],
[True, "TRTLLM", False, False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, False, True, True],
[False, "TRTLLM", False, False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, True, True, False],
[False, "TRTLLM", False, False, False, False, True, True, False],
[True, "TRTLLM", False, False, False, False, False, False, False],
[False, "TRTLLM", False, False, False, False, False, False, False],
[True, "TRTLLM", False, False, False, True, True, False, False],
[True, "TRTLLM", False, False, False, True, False, False, False],
[True, "FLASHINFER", False, False, False, False, True, False, False],
[False, "FLASHINFER", False, False, False, False, True, False, False],
# [True, "TRTLLM", True, False, False, True, True, False, False, False],
[True, "TRTLLM", False, False, False, False, True, False, False, False],
[
False, "TRTLLM", False, False, False, False, True, False, False,
False
],
[True, "TRTLLM", False, False, False, False, False, True, False, False],
[True, "TRTLLM", False, False, False, False, False, True, True, False],
[
False, "TRTLLM", False, False, False, False, False, True, False,
False
],
[True, "TRTLLM", False, False, False, False, True, True, False, False],
[False, "TRTLLM", False, False, False, False, True, True, False, False],
[
True, "TRTLLM", False, False, False, False, False, False, False,
False
],
[
False, "TRTLLM", False, False, False, False, False, False, False,
False
],
[True, "TRTLLM", False, False, False, True, True, False, False, False],
[True, "TRTLLM", False, False, False, True, False, False, False, False],
[
True, "FLASHINFER", False, False, False, False, True, False, False,
False
],
[
False, "FLASHINFER", False, False, False, False, True, False, False,
False
],
[True, "TRTLLM", False, True, True, True, True, True, True, True],
])
@pytest.mark.high_cuda_memory
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
disable_overlap_scheduler: bool, enable_block_reuse: bool,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool, multi_batch: bool,
attention_dp: bool, request):
attention_dp: bool, fp8_target: bool, request):
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
@@ -65,13 +93,17 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
models_path = llm_models_root()
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
if fp8_target:
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"

# bs > 1 gives non-deterministic results when doing IFB. There is a slight
# chance that ref and spec do not match 100%.
max_batch_size = 4 if multi_batch else 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
if fp8_target:
kv_cache_config.dtype = 'fp8'
cuda_graph_config = CudaGraphConfig(
batch_sizes=[i for i in range(1, max_batch_size +
1)]) if use_cuda_graph else None
@@ -151,9 +183,10 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()

for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
if not fp8_target:
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref


def test_deepseek_eagle3():