
Conversation

Contributor

@IwakuraRein commented Aug 12, 2025

📌 Description

  • Update the AutoTuner, OptimizationProfile, and DynamicTensorSpec.
    • A DynamicTensorSpec can now cover multiple input tensors.
    • The tensor_initializers field in DynamicTensorSpec defines how each dynamic tensor is initialized during tuning. Previously all dynamic tensors were zero-initialized, which caused illegal memory accesses (IMA) in trtllm-gen's routing kernels (see the sketch below).
  • Add DtypeTrtllmGen in flashinfer/fused_moe/core.py.
  • Add the autotuner to the trtllm-gen fp4 MoE.
  • Relax the check on hidden_states_scale in the trtllm-gen fp4 MoE: it no longer needs to be 1D.
    • When autotuning, it must be 2D.
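
To illustrate the per-tensor initializer idea, here is a minimal self-contained sketch. DynamicTensorSpecSketch and its field names are illustrative stand-ins, not the actual flashinfer DynamicTensorSpec API.

    # Minimal sketch of the idea above; names are stand-ins, not the real flashinfer API.
    from dataclasses import dataclass
    from typing import Callable, Tuple

    import torch


    @dataclass
    class DynamicTensorSpecSketch:
        input_indices: Tuple[int, ...]   # several input tensors may share the dynamic dimension
        dim_idx: int                     # which dimension is dynamic (e.g. num_tokens)
        tuning_buckets: Tuple[int, ...]  # candidate sizes tried during tuning
        # One initializer per dynamic tensor: routing inputs need valid values (e.g. expert
        # ids in range), since zero-filled buffers can drive the routing kernels out of
        # bounds -- the IMA mentioned above.
        tensor_initializers: Tuple[Callable[..., torch.Tensor], ...]


    spec = DynamicTensorSpecSketch(
        input_indices=(0, 1),
        dim_idx=0,
        tuning_buckets=(1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024),
        tensor_initializers=(
            lambda shape, dtype, device: torch.randn(shape, device=device).to(dtype),
            lambda shape, dtype, device: torch.randint(0, 8, shape, dtype=dtype, device=device),
        ),
    )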

TODOs

  • Unify the launcher for both trtllm-gen fp8 and fp4 MoE.
  • After unifying the launchers, add the autotuner to the fp8 MoE.
  • If routing is DeepSeek V3, there will be an illegal memory access. A workaround (WAR) is to limit the search space to [8 ... 1024].

Performance

B200, clock speed locked at 1500 MHz, 1000 warmups, 1000 iterations, mxfp4 x mxfp8

num_tokens   w/o tuner   with tuner   diff
         1       0.042        0.042    0.00%
         2       0.057        0.057    0.00%
         4       0.081        0.081    0.00%
         8       0.114        0.105    7.89%
        16       0.201        0.183    8.96%
        32       0.274        0.246   10.22%
        64       0.348        0.308   11.49%
       128       0.410        0.365   10.98%
       256       0.548        0.429   21.72%
       512       0.576        0.453   21.35%
      1024       0.651        0.578   11.21%

For nvfp4 x nvfp4 and mxfp4 x bf16, there is no significant perf gain.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Contributor

@gemini-code-assist gemini-code-assist bot left a comment
Summary of Changes

Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces an autotuning mechanism for FP4 Mixture-of-Experts (MoE) kernels generated by TensorRT-LLM (trtllm-gen). It enables the system to dynamically select optimal kernel configurations for improved performance. The changes involve extending the C++ kernel interface to accept specific configurations, implementing a Python-based MoERunner to manage and cache valid configurations, and integrating this runner with the existing autotuner framework. Additionally, the tensor creation logic for autotuning is enhanced to use more realistic random data, which can lead to more accurate tuning results.

Highlights

  • Autotuning for FP4 MoE Kernels: A comprehensive autotuning framework has been introduced for trtllm-gen FP4 Mixture-of-Experts (MoE) kernels, allowing the system to dynamically find and apply optimal kernel configurations for enhanced performance.
  • Dynamic Kernel Configuration: The underlying C++ trtllm_fp4_block_scale_moe function now accepts a config_index, enabling the dynamic selection of specific, pre-optimized kernel configurations at runtime.
  • MoERunner Implementation: A new MoERunner class has been implemented to manage and cache valid kernel configurations. This includes logic for calculating an optimal tile_tokens_dim (a rough sketch of such a heuristic follows this list) and querying available tactics from the C++ backend, streamlining the tuning process.
  • Enhanced Autotuner Input Data: The _create_tensor_like utility in the autotuner has been updated to generate random data for integer and floating-point types, providing more realistic input tensors during tuning and potentially leading to more accurate performance optimizations.
  • Integration with AutoTuner: The Python frontend functions (trtllm_fp4_block_scale_moe_op, trtllm_fp4_block_scale_moe, trtllm_fp4_block_scale_routed_moe) have been modified to seamlessly integrate with the AutoTuner and MoERunner, ensuring that the best-performing kernel configuration is automatically selected and used.
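
As a rough illustration of the tile_tokens_dim logic mentioned above, the sketch below rounds the expected per-expert token count up to a power of two and clamps it to a kernel-friendly range; the exact formula and clamp bounds used by the MoERunner may differ.

    # Illustrative guess at the kind of tile_tokens_dim heuristic the MoERunner caches;
    # the exact formula and clamp range in flashinfer may differ.
    def calc_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Expected number of tokens routed to each expert.
        tokens_per_expert = max(1, (num_tokens * top_k) // num_experts)
        # Round up to the next power of two, then clamp to a kernel-friendly range.
        tile = 1 << (tokens_per_expert - 1).bit_length()
        return min(max(tile, 8), 64)


    print(calc_tile_tokens_dim(num_tokens=256, top_k=8, num_experts=128))  # -> 16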

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces an autotuner for the Trtllm-gen FP4 MoE kernels, which is a valuable addition for performance optimization. The changes span both the C++ backend and the Python-level autotuner infrastructure. The overall approach is sound, but there are a few issues that need to be addressed, including a critical bug that could cause a runtime error, some minor code correctness issues, and leftover debugging code. Addressing these points will improve the robustness and quality of the implementation.

Comment on lines 1084 to 1091
            extra_input_idx = 0
            if trtllm_gen_dtype_has_scale(self.dtype_act):
                hidden_states_scale = extra_inputs[extra_input_idx]
                extra_input_idx += 1
            if trtllm_gen_dtype_has_scale(self.dtype_weights):
                gemm1_weights_scale = extra_inputs[extra_input_idx]
                gemm2_weights_scale = extra_inputs[extra_input_idx + 1]
                extra_input_idx += 2
Contributor

critical

The variables hidden_states_scale, gemm1_weights_scale, and gemm2_weights_scale are only defined within if blocks. If the conditions are false, these variables will not be defined, leading to a NameError when they are used in the moe_op.trtllm_fp4_block_scale_moe call on line 1094. You should initialize them to None before the conditional blocks to ensure they are always defined.

            extra_input_idx = 0
            hidden_states_scale = None
            if trtllm_gen_dtype_has_scale(self.dtype_act):
                hidden_states_scale = extra_inputs[extra_input_idx]
                extra_input_idx += 1
            gemm1_weights_scale = None
            gemm2_weights_scale = None
            if trtllm_gen_dtype_has_scale(self.dtype_weights):
                gemm1_weights_scale = extra_inputs[extra_input_idx]
                gemm2_weights_scale = extra_inputs[extra_input_idx + 1]
                extra_input_idx += 2

default:
TORCH_CHECK(false, "Invalid trtllm-gen dtype");
}
return btg::Dtype::E2m1;
Contributor

medium

This line is unreachable because the default case in the switch statement above it will always throw an exception with TORCH_CHECK(false, ...). The compiler should have issued a warning about this. Please remove this line.

Comment on lines +94 to +100
if dtype in [
DtypeTrtllmGen.MxE4m3,
DtypeTrtllmGen.E2m1,
DtypeTrtllmGen.MxE2m1,
DtypeTrtllmGen.MxE4m3,
]:
Contributor

medium

The value DtypeTrtllmGen.MxE4m3 is duplicated in this list. While this doesn't cause a functional issue, it's redundant and should be cleaned up for better code clarity.

    if dtype in [
        DtypeTrtllmGen.MxE4m3,
        DtypeTrtllmGen.E2m1,
        DtypeTrtllmGen.MxE2m1,
    ]:

self.dtype_act = dtype_act
self.dtype_weights = dtype_weights
self.use_deepseek_fp8 = use_deepseek_fp8
self.top_k = top_k
Contributor

medium

The assignment self.top_k = top_k on line 986 is a duplicate of the assignment on line 982. Please remove the redundant line.

self.num_experts,
num_tokens,
)
print(f"instance_key: {instance_key}")
Contributor

medium

This print statement appears to be for debugging purposes. It should be removed from the final code to avoid polluting the output during execution.

Comment on lines 1418 to 1422
tunning_config = (
MoERunner.tuning_config_no_hidden_states_scales
if hidden_states_scale is None
else MoERunner.tuning_config_with_hidden_states_scales
)
Contributor

medium

There is a typo in the variable name tunning_config. It should be tuning_config for consistency with the rest of the autotuner code.

            tuning_config = (
                MoERunner.tuning_config_no_hidden_states_scales
                if hidden_states_scale is None
                else MoERunner.tuning_config_with_hidden_states_scales
            )

@IwakuraRein force-pushed the trtllm-gen-moe-autotuner branch from e4fb808 to e533cac on August 13, 2025 at 23:15
@nvpohanh
Contributor

cc @amirkl94

@IwakuraRein force-pushed the trtllm-gen-moe-autotuner branch from 4d0b914 to 37316e9 on August 14, 2025 at 19:09
@IwakuraRein marked this pull request as ready for review on August 14, 2025 at 19:09

inline btg::Dtype get_dtype(int64_t const dtype) {
switch (dtype) {
case 0:
Contributor

This currently finds the type by Dtype's Uid bits.
Would it be more maintainable to simply pass strings (so that in code review the case more obviously matches the type name) and document at the binding interface that the available options are in DtypeDecl.h's Dtype?

Contributor

The current approach has the advantage of being checked early via class DtypeTrtllmGen(IntEnum), so no objections; just raising options.

Contributor Author

@aleozlx Thanks for the suggestion! What about defining a new enum in trtllm_fused_moe_kernel_launcher.cu, using macros to map it to btg::Dtype, and then exposing it to Python?

Contributor

@aleozlx commented Aug 15, 2025

Sure, that improves readability. Or we can use the class Dtype(IntEnum) I proposed in the other thread and remove another conversion. Hopefully __new__ serves to hide implementation details so that users won't rely on the value (compared to a separate function). Value changes can still cause breakage through the C++ interface, though, so we just have to align on something early.

@@ -65,6 +66,43 @@ class RoutingMethodType(IntEnum):
Unspecified = 5


# NOTE(siyuan): Need to keep this in sync with the counterpart defined in include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h
class DtypeTrtllmGen(IntEnum):
Bfloat16 = (0,)
Contributor

Or we can borrow the bit layout from TLLM_ENCODE_DTYPE, allowing the deletion of the get_dtype() conversion.

e.g.

class Dtype(IntEnum):
    def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid):
        value = (block_format_bit << 24) | (signed_bit << 20) | (integer_bit << 16) | (num_bits << 8) | uid
        obj = int.__new__(cls, value)
        obj._value_ = value
        return obj
    
    # keep the values in sync with include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h
    Bfloat16 = (0, 1, 0, 16, 0)
    Bool     = (0, 0, 1, 1,  1)
    E2m1     = (1, 1, 0, 4,  2)
    E2m3     = (1, 1, 0, 6,  3)
    E3m2     = (1, 1, 0, 6,  4)
    E4m3     = (0, 1, 0, 8,  5)
    E5m2     = (0, 1, 0, 8,  6)
    Fp16     = (0, 1, 0, 16, 7)
    Fp32     = (0, 1, 0, 32, 8)
    Int8     = (0, 1, 1, 8,  9)
    Int32    = (0, 1, 1, 32, 10)
    Int64    = (0, 1, 1, 64, 11)
    MxE2m1   = (1, 1, 0, 4,  12)
    MxE4m3   = (1, 1, 0, 8,  13)
    UE8m0    = (0, 0, 0, 8,  14)
    UInt8    = (0, 0, 1, 8,  15)
    UInt16   = (0, 0, 1, 16, 16)
    UInt32   = (0, 0, 1, 32, 17)
    UInt64   = (0, 0, 1, 64, 18)
    UInt128  = (0, 0, 1, 128, 19)
    Void     = (0, 1, 0, 0,  20)

@@ -1481,6 +1840,7 @@ def trtllm_fp4_block_scale_routed_moe(
routing_method_type: int = 0,
do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 1024,
Contributor

Shall we always append new parameters at the end? That way users who pass the optional arguments positionally won't break.

Or, in hindsight, we should have made the optionals keyword-only (placed after a bare *) so they cannot be used positionally at all, since positional use is easier to break; a small sketch follows.
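
For example (a sketch only, not the actual flashinfer signature): everything after the bare * must be passed by keyword, so appending new tuning knobs later cannot shift existing positional arguments.

    from typing import Optional

    # Sketch only -- not the real flashinfer signature. Parameters after the bare `*`
    # are keyword-only, so adding tune_max_num_tokens (or future knobs) at the end
    # cannot break callers that pass the required arguments positionally.
    def trtllm_fp4_block_scale_routed_moe_sketch(
        routing_logits,
        hidden_states,
        *,
        routing_method_type: int = 0,
        do_finalize: bool = True,
        enable_pdl: Optional[bool] = None,
        tune_max_num_tokens: int = 1024,
    ):
        ...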

@aleozlx mentioned this pull request on Aug 15, 2025
dtype = DtypeTrtllmGen.Bfloat16
elif x.dtype == torch.float8_e4m3fn:
dtype = DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3
elif x.dtype == torch.uint8:
Collaborator

Just a minor note: we should also take care of torch.float4_e2m1x2 for torch 2.8+

Contributor Author

Thanks. But I don't see float4_e2m1fn_x2 used anywhere in flashinfer yet; will we add it all together in the future?

Collaborator

Not yet; we need to be prepared for it as frameworks upgrade to torch 2.8.

It could be done in later PRs.
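
A hedged sketch of how that torch 2.8+ branch could look, following the dtype-inference chain quoted in the diff; the hasattr guard, the packed-FP4 dtype name, and the E2m1/MxE2m1 mapping are assumptions rather than confirmed flashinfer code.

    import torch

    from flashinfer.fused_moe.core import DtypeTrtllmGen  # module added in this PR


    def infer_trtllm_gen_dtype(x: torch.Tensor, scale=None) -> DtypeTrtllmGen:
        """Sketch of the dtype-inference chain with a torch 2.8+ packed-FP4 branch."""
        if x.dtype == torch.bfloat16:
            return DtypeTrtllmGen.Bfloat16
        if x.dtype == torch.float8_e4m3fn:
            return DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3
        # Assumed branch for torch >= 2.8, where packed FP4 gets its own dtype.
        if hasattr(torch, "float4_e2m1fn_x2") and x.dtype == torch.float4_e2m1fn_x2:
            return DtypeTrtllmGen.E2m1 if scale is None else DtypeTrtllmGen.MxE2m1
        raise ValueError(f"Unsupported dtype: {x.dtype}")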

@yzh119
Collaborator

yzh119 commented Aug 19, 2025

There are some conflicts with the main branch after #1396 got merged; would you mind rebasing?

Collaborator

@yzh119 left a comment

LGTM, next step is to get #1494 in, which is dependent on this one.

@yzh119 changed the title from "Trtllm-gen Fp4 MoE Autotunner" to "tuner: Trtllm-gen Fp4 MoE Autotunner" on Aug 19, 2025
@yzh119 merged commit f1fd5c6 into flashinfer-ai:main on Aug 19, 2025 (2 checks passed)
assert hidden_states.shape[0] == num_tokens, (
"hidden_states's first dimension must be batch size."
)
assert hidden_states_scale is None or (
Collaborator

This assertion makes vllm gpt-oss fail. In that case the scale is tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0', dtype=torch.float8_e4m3fn).

Contributor Author

@weireweire I created this PR to fix this. Currently vllm's flashinfer tag is 0.2.12; I can mark this PR as ready and bump the flashinfer tag to 0.2.13.


# TODO(siyuan): support fp8
moe_op.trtllm_fp4_block_scale_moe(
routing_logits.to(torch.bfloat16),
Contributor Author

This line actually breaks DeepSeek V3 routing. I have left suggested changes in #1494.
