From f27502ec23796832698ae8bc343008a4c0fe401f Mon Sep 17 00:00:00 2001 From: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com> Date: Tue, 4 Nov 2025 12:49:41 +0000 Subject: [PATCH] fix eagle3 fp8 target model + bf16 draft model Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com> --- cpp/kernels/fmha_v2/setup.py | 16 ++- .../decoderXQAImplJIT/decoderXQAImplJIT.cpp | 8 ++ .../_torch/speculative/test_eagle3.py | 112 ++++++++++++++++++ 3 files changed, 130 insertions(+), 6 deletions(-) diff --git a/cpp/kernels/fmha_v2/setup.py b/cpp/kernels/fmha_v2/setup.py index c1774e389e9..b1f877dcc33 100644 --- a/cpp/kernels/fmha_v2/setup.py +++ b/cpp/kernels/fmha_v2/setup.py @@ -3063,7 +3063,9 @@ def get_kernel_traits_code(specs_names): # 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins for accuracy regressions (will be fixed). # You should set the condition `use_cubin_header` to false if you have modified the source codes of those kernels that use cubins. # This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins. -def use_cubin_header(sm, head_size, dtype): +def use_cubin_header(sm, head_size, dtype, output_dtype=None): + if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']: + return False return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype) @@ -3074,7 +3076,7 @@ def get_cubin_header(kernel_traits, specs_names): cubin_lens_dict = {} for kspec, fname, lname, kname in specs_names: if generate_cu_trtllm and not use_cubin_header( - kspec.sm, kspec.head_size, kspec.dtype): + kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype): continue name = fname.replace('.', '_') data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name) @@ -3229,7 +3231,8 @@ def get_cubin_header(kernel_traits, specs_names): if generate_cu_trtllm: def get_lname_from_kname(kname: str) -> str: - if use_cubin_header(int(sm), int(head_size), prec.lower()): + if use_cubin_header(int(sm), int(head_size), prec.lower(), + output_prec.lower()): return 'nullptr' lname = kname.replace('_kernel', '') mask_types = [ @@ -3248,8 +3251,9 @@ def get_lname_from_kname(kname: str) -> str: {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \ {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \ {is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\ -'''.format(**locals()) if use_cubin_header(int(sm), int(head_size), - prec.lower()) else '''\ +'''.format(**locals()) if use_cubin_header(int(sm), + int(head_size), prec.lower(), + output_prec.lower()) else '''\ {{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \ {sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \ 0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \ @@ -3791,7 +3795,7 @@ def enumerate_qgmma_flash_warpspec_kernels(specs, continue # for normal attention, we do not need return softmax for ws fp8 kernels currently. # also fp8 input and bf16 output is only needed for MLA kernel. - skip_combination = return_softmax or (output_dtype is not None) + skip_combination = return_softmax # for context mla, we need separate qkv as input layout when returning softmax. skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V if not skip_combination: diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp index 79e9694b95f..c2eb6257d67 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp @@ -124,6 +124,14 @@ bool DecoderXQAImplJIT::shouldUse(XQAParams const& umbrellaXQAParams, bool forCo bool hasPerfGain = mayHavePerfGain(xqaParams); if (!hasPerfGain) { + if (!xqaParams.is_fp8_output && xqaParams.kv_cache_data_type == DATA_TYPE_E4M3 + && (xqaParams.data_type == DATA_TYPE_BF16 || xqaParams.data_type == DATA_TYPE_FP16)) + { + TLLM_LOG_DEBUG( + "JIT XQA is selected in the generation phase for fp16/bf16 input and e4m3 kv cache because MMHA " + "does not support this combination."); + return true; + } TLLM_LOG_DEBUG("JIT XQA is not used: maybe no performance gain"); return false; } diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index f6b21ecbbd3..da95386390a 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -520,5 +520,117 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool): llm_spec.shutdown() +@pytest.mark.parametrize( + "enable_block_reuse,use_one_model,enable_chunked_prefill,fp8_target", [ + [True, True, True, True], + ]) +@pytest.mark.high_cuda_memory +def test_qwen3_eagle3(enable_block_reuse: bool, use_one_model: bool, + enable_chunked_prefill: bool, fp8_target: bool): + # Eagle3 one model works with overlap scheduler and block reuse. + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 35: + pytest.skip("Not enough memory to load target + draft model") + + use_cuda_graph = True + attn_backend = "TRTLLM" + disable_overlap_scheduler = False + use_chain_drafter = True + multi_batch = False + attention_dp = False + + models_path = llm_models_root() + eagle_model_dir = "/ziqingc_large/03_Data/models/Zhi-Create-Qwen3-32B-Eagle3" # temp + target_model_dir = f"{models_path}/Qwen3/Qwen3-32B" + if fp8_target: + target_model_dir = f"{models_path}/Qwen3/Qwen3-32B-FP8/" + + # bs > 1 gives non-deterministic when doing IFB. There are slight chances + # that ref and spec does not match 100% + max_batch_size = 4 if multi_batch else 1 + max_draft_len = 3 + kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, + max_tokens=8192) + if fp8_target: + kv_cache_config.dtype = 'fp8' + cuda_graph_config = CudaGraphConfig( + batch_sizes=[i for i in range(1, max_batch_size + + 1)]) if use_cuda_graph else None + + llm_common_config = dict( + model=target_model_dir, + attn_backend=attn_backend, + disable_overlap_scheduler=disable_overlap_scheduler, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + enable_attention_dp=attention_dp, + max_seq_len=8192, + enable_chunked_prefill=enable_chunked_prefill, + ) + if enable_chunked_prefill: + # Use a small max_num_tokens so that the chunked prefill path gets exercised. + llm_common_config['max_num_tokens'] = 64 + + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + speculative_model_dir=eagle_model_dir, + eagle3_one_model=use_one_model, + ) + spec_config._allow_chain_drafter = use_chain_drafter + + # Create the LLM instance + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + # Acceptance rate tests + if enable_chunked_prefill: + # Use a long prompt for chunked prefill tests. + prompts = [ + "The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and " + ] + tok_ids = [llm_spec.tokenizer.encode(prompts[0])] + else: + prompts = [ + "The capital of France is", + "The president of the United States is", + ] + tok_ids = [llm_spec.tokenizer.encode("The future of AI is")] + if multi_batch: + tok_ids.append(llm_spec.tokenizer.encode(prompts)) + + sampling_params = SamplingParams(max_tokens=128, temperature=0) + for i in range(len(tok_ids)): + num_tokens = 0 + num_drafted = 0 + num_accepted = 0 + + for output in llm_spec.generate_async(tok_ids[i], + sampling_params, + streaming=True): + new_tokens = output.outputs[0].token_ids + num_drafted += max_draft_len + num_accepted += len(new_tokens) - num_tokens - 1 + num_tokens = len(new_tokens) + + accept_rate = num_accepted / num_drafted + assert accept_rate > 0.15 + + # Output tests + sampling_params = SamplingParams(max_tokens=10, temperature=0) + + results_spec = llm_spec.generate(prompts, sampling_params) + generated_text_spec = [result.outputs[0].text for result in results_spec] + llm_spec.shutdown() + + llm_ref = LLM(**llm_common_config) + results_ref = llm_ref.generate(prompts, sampling_params) + generated_text_ref = [result.outputs[0].text for result in results_ref] + llm_ref.shutdown() + + for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): + # The spec decode algorithm currently guarantees identical results + assert text_spec == text_ref + + if __name__ == "__main__": unittest.main()