16 changes: 10 additions & 6 deletions cpp/kernels/fmha_v2/setup.py
@@ -3063,7 +3063,9 @@ def get_kernel_traits_code(specs_names):
# 2. Hopper sm89 with e4m3/e4m3_fp32 dtype uses cubins for accuracy regressions (will be fixed).
# You should set the condition `use_cubin_header` to false if you have modified the source codes of those kernels that use cubins.
# This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins.
def use_cubin_header(sm, head_size, dtype):
def use_cubin_header(sm, head_size, dtype, output_dtype=None):
if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
return False
return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)
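
As a quick illustration of the updated predicate (a minimal sketch; the function body is copied from this diff, while the sample argument combinations are assumptions chosen for demonstration):

def use_cubin_header(sm, head_size, dtype, output_dtype=None):
    # e4m3 (fp8) input with fp16/bf16 output is compiled from source, never taken from cubins.
    if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
        return False
    return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)

assert use_cubin_header(90, 128, 'e4m3', 'bf16') is False  # new early-exit path added by this PR
assert use_cubin_header(90, 128, 'fp16') is True           # sm90 with head_size 128 still uses cubins
assert use_cubin_header(89, 64, 'e4m3') is True            # sm89 e4m3 still uses cubins
assert use_cubin_header(80, 128, 'fp16') is False          # everything else is compiled from source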


@@ -3074,7 +3076,7 @@ def get_cubin_header(kernel_traits, specs_names):
cubin_lens_dict = {}
for kspec, fname, lname, kname in specs_names:
if generate_cu_trtllm and not use_cubin_header(
kspec.sm, kspec.head_size, kspec.dtype):
kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype):
continue
name = fname.replace('.', '_')
data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
@@ -3229,7 +3231,8 @@ def get_cubin_header(kernel_traits, specs_names):
if generate_cu_trtllm:

def get_lname_from_kname(kname: str) -> str:
if use_cubin_header(int(sm), int(head_size), prec.lower()):
if use_cubin_header(int(sm), int(head_size), prec.lower(),
output_prec.lower()):
return 'nullptr'
lname = kname.replace('_kernel', '')
mask_types = [
@@ -3248,8 +3251,9 @@ def get_lname_from_kname(kname: str) -> str:
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
'''.format(**locals()) if use_cubin_header(int(sm), int(head_size),
prec.lower()) else '''\
'''.format(**locals()) if use_cubin_header(int(sm),
int(head_size), prec.lower(),
output_prec.lower()) else '''\
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
@@ -3791,7 +3795,7 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
continue
# for normal attention, we do not need return softmax for ws fp8 kernels currently.
# also fp8 input and bf16 output is only needed for MLA kernel.
skip_combination = return_softmax or (output_dtype is not None)
skip_combination = return_softmax
# for context mla, we need separate qkv as input layout when returning softmax.
skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
if not skip_combination:
@@ -124,6 +124,14 @@ bool DecoderXQAImplJIT::shouldUse(XQAParams const& umbrellaXQAParams, bool forCo
bool hasPerfGain = mayHavePerfGain(xqaParams);
if (!hasPerfGain)
{
if (!xqaParams.is_fp8_output && xqaParams.kv_cache_data_type == DATA_TYPE_E4M3
&& (xqaParams.data_type == DATA_TYPE_BF16 || xqaParams.data_type == DATA_TYPE_FP16))
{
TLLM_LOG_DEBUG(
"JIT XQA is selected in the generation phase for fp16/bf16 input and e4m3 kv cache because MMHA "
"does not support this combination.");
return true;
}
TLLM_LOG_DEBUG("JIT XQA is not used: maybe no performance gain");
return false;
}
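For illustration, the new fallback condition in DecoderXQAImplJIT::shouldUse can be summarized as follows (a Python rendition of the C++ check above, not the actual implementation; the parameter names mirror fields of XQAParams):

# Even when mayHavePerfGain() is false, JIT XQA is still selected for fp16/bf16 input with an
# e4m3 (fp8) KV cache and non-fp8 output, because MMHA does not support that combination.
def falls_back_to_jit_xqa(is_fp8_output: bool, kv_cache_data_type: str, data_type: str) -> bool:
    return (not is_fp8_output
            and kv_cache_data_type == 'E4M3'
            and data_type in ('BF16', 'FP16'))

assert falls_back_to_jit_xqa(False, 'E4M3', 'BF16') is True   # forced onto JIT XQA
assert falls_back_to_jit_xqa(False, 'FP16', 'BF16') is False  # normal perf-gain gating applies
assert falls_back_to_jit_xqa(True, 'E4M3', 'BF16') is False   # fp8 output: not the MMHA gap this targets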
112 changes: 112 additions & 0 deletions tests/unittest/_torch/speculative/test_eagle3.py
@@ -520,5 +520,117 @@ def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):
llm_spec.shutdown()


@pytest.mark.parametrize(
"enable_block_reuse,use_one_model,enable_chunked_prefill,fp8_target", [
[True, True, True, True],
])
@pytest.mark.high_cuda_memory
def test_qwen3_eagle3(enable_block_reuse: bool, use_one_model: bool,
enable_chunked_prefill: bool, fp8_target: bool):
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")

use_cuda_graph = True
attn_backend = "TRTLLM"
disable_overlap_scheduler = False
use_chain_drafter = True
multi_batch = False
attention_dp = False

models_path = llm_models_root()
eagle_model_dir = "/ziqingc_large/03_Data/models/Zhi-Create-Qwen3-32B-Eagle3" # temp
target_model_dir = f"{models_path}/Qwen3/Qwen3-32B"
if fp8_target:
target_model_dir = f"{models_path}/Qwen3/Qwen3-32B-FP8/"

Comment on lines +543 to +547 (Contributor):

⚠️ Potential issue | 🔴 Critical

Remove hard-coded absolute draft-model path

eagle_model_dir points to /ziqingc_large/..., which only exists on the author’s workstation. In CI (or any other dev machine) this path resolves to nothing, so the LLM load will throw before the test even starts. Please source the draft checkpoint via llm_models_root() (or a fixture that stages the model) so the test remains runnable everywhere.

Apply this diff as a starting point:

-    eagle_model_dir = "/ziqingc_large/03_Data/models/Zhi-Create-Qwen3-32B-Eagle3"  # temp
+    eagle_model_dir = os.path.join(
+        models_path, "Zhi-Create-Qwen3-32B-Eagle3")

(Adjust the relative directory if the checkpoint sits elsewhere under LLM_MODELS_ROOT.)

🤖 Prompt for AI Agents
In tests/unittest/_torch/speculative/test_eagle3.py around lines 543 to 547,
remove the hard-coded absolute path assigned to eagle_model_dir
(/ziqingc_large/...), and instead construct the path relative to the test's
model root (use llm_models_root() or a provided fixture that stages models).
Replace the literal with a join of llm_models_root() and the relative checkpoint
directory (or use the fixture value) so the test resolves the draft checkpoint
on CI and other machines; adjust the relative subpath as needed to match where
the checkpoint lives under LLM_MODELS_ROOT.

# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec do not match 100%
max_batch_size = 4 if multi_batch else 1
max_draft_len = 3
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
if fp8_target:
kv_cache_config.dtype = 'fp8'
cuda_graph_config = CudaGraphConfig(
batch_sizes=[i for i in range(1, max_batch_size +
1)]) if use_cuda_graph else None

llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
enable_attention_dp=attention_dp,
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64

spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,
)
spec_config._allow_chain_drafter = use_chain_drafter

# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

# Acceptance rate tests
if enable_chunked_prefill:
# Use a long prompt for chunked prefill tests.
prompts = [
"The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
]
tok_ids = [llm_spec.tokenizer.encode(prompts[0])]
else:
prompts = [
"The capital of France is",
"The president of the United States is",
]
tok_ids = [llm_spec.tokenizer.encode("The future of AI is")]
if multi_batch:
tok_ids.append(llm_spec.tokenizer.encode(prompts))

sampling_params = SamplingParams(max_tokens=128, temperature=0)
for i in range(len(tok_ids)):
num_tokens = 0
num_drafted = 0
num_accepted = 0

for output in llm_spec.generate_async(tok_ids[i],
sampling_params,
streaming=True):
new_tokens = output.outputs[0].token_ids
num_drafted += max_draft_len
num_accepted += len(new_tokens) - num_tokens - 1
num_tokens = len(new_tokens)

accept_rate = num_accepted / num_drafted
assert accept_rate > 0.15

# Output tests
sampling_params = SamplingParams(max_tokens=10, temperature=0)

results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()

for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
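
For reference, the acceptance-rate bookkeeping in the streaming loop above can be sanity-checked with a small worked example (the step count and per-step chunk lengths below are assumed, not actual test output):

max_draft_len = 3
tokens_per_step = [3, 1, 2, 4] * 10                   # hypothetical length of each streamed chunk (40 steps)
num_drafted = max_draft_len * len(tokens_per_step)    # 3 * 40 = 120 drafted tokens
num_accepted = sum(n - 1 for n in tokens_per_step)    # each step emits 1 target token plus accepted drafts -> 60
accept_rate = num_accepted / num_drafted              # 0.5, comfortably above the 0.15 threshold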


if __name__ == "__main__":
unittest.main()