tuner: Trtllm-gen Fp4 MoE Autotunner #1475
Conversation
Summary of Changes
Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces an autotuning mechanism for FP4 Mixture-of-Experts (MoE) kernels generated by TensorRT-LLM (trtllm-gen). It enables the system to dynamically select optimal kernel configurations for improved performance. The changes involve extending the C++ kernel interface to accept specific configurations, implementing a Python-based MoERunner
to manage and cache valid configurations, and integrating this runner with the existing autotuner framework. Additionally, the tensor creation logic for autotuning is enhanced to use more realistic random data, which can lead to more accurate tuning results.
Highlights
- Autotuning for FP4 MoE Kernels: A comprehensive autotuning framework has been introduced for
trtllm-gen
FP4 Mixture-of-Experts (MoE) kernels, allowing the system to dynamically find and apply optimal kernel configurations for enhanced performance. - Dynamic Kernel Configuration: The underlying C++
trtllm_fp4_block_scale_moe
function now accepts aconfig_index
, enabling the dynamic selection of specific, pre-optimized kernel configurations at runtime. MoERunner
Implementation: A newMoERunner
class has been implemented to manage and cache valid kernel configurations. This includes logic for calculating optimaltile_tokens_dim
and querying available tactics from the C++ backend, streamlining the tuning process.- Enhanced Autotuner Input Data: The
_create_tensor_like
utility in the autotuner has been updated to generate random data for integer and floating-point types, providing more realistic input tensors during tuning and potentially leading to more accurate performance optimizations. - Integration with
AutoTuner
: The Python frontend functions (trtllm_fp4_block_scale_moe_op
,trtllm_fp4_block_scale_moe
,trtllm_fp4_block_scale_routed_moe
) have been modified to seamlessly integrate with theAutoTuner
andMoERunner
, ensuring that the best-performing kernel configuration is automatically selected and used.
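To make the selection flow concrete, here is a minimal sketch of the tune-then-cache loop described above; the helper names (pick_best_config, run_moe) and the timing logic are illustrative assumptions, not the actual AutoTuner/MoERunner API introduced by this PR.

# Minimal sketch of the tuning flow: query the valid kernel configs, time each
# candidate, and reuse the fastest config_index afterwards. Names are illustrative.
import time
from typing import Callable, List

def pick_best_config(valid_configs: List[int], run_moe: Callable[[int], None]) -> int:
    best_config, best_time = valid_configs[0], float("inf")
    for config_index in valid_configs:
        start = time.perf_counter()
        run_moe(config_index)  # launch the MoE kernel with this candidate config
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_config, best_time = config_index, elapsed
    return best_config  # the tuner caches this and passes it as config_index later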
Code Review
This pull request introduces an autotuner for the Trtllm-gen FP4 MoE kernels, which is a valuable addition for performance optimization. The changes span both the C++ backend and the Python-level autotuner infrastructure. The overall approach is sound, but there are a few issues that need to be addressed, including a critical bug that could cause a runtime error, some minor code correctness issues, and leftover debugging code. Addressing these points will improve the robustness and quality of the implementation.
flashinfer/fused_moe/core.py
Outdated
extra_input_idx = 0
if trtllm_gen_dtype_has_scale(self.dtype_act):
    hidden_states_scale = extra_inputs[extra_input_idx]
    extra_input_idx += 1
if trtllm_gen_dtype_has_scale(self.dtype_weights):
    gemm1_weights_scale = extra_inputs[extra_input_idx]
    gemm2_weights_scale = extra_inputs[extra_input_idx + 1]
    extra_input_idx += 2
The variables hidden_states_scale
, gemm1_weights_scale
, and gemm2_weights_scale
are only defined within if
blocks. If the conditions are false, these variables will not be defined, leading to a NameError
when they are used in the moe_op.trtllm_fp4_block_scale_moe
call on line 1094. You should initialize them to None
before the conditional blocks to ensure they are always defined.
extra_input_idx = 0
hidden_states_scale = None
if trtllm_gen_dtype_has_scale(self.dtype_act):
    hidden_states_scale = extra_inputs[extra_input_idx]
    extra_input_idx += 1
gemm1_weights_scale = None
gemm2_weights_scale = None
if trtllm_gen_dtype_has_scale(self.dtype_weights):
    gemm1_weights_scale = extra_inputs[extra_input_idx]
    gemm2_weights_scale = extra_inputs[extra_input_idx + 1]
    extra_input_idx += 2
    default:
      TORCH_CHECK(false, "Invalid trtllm-gen dtype");
  }
  return btg::Dtype::E2m1;
if dtype in [
    DtypeTrtllmGen.MxE4m3,
    DtypeTrtllmGen.E2m1,
    DtypeTrtllmGen.MxE2m1,
    DtypeTrtllmGen.MxE4m3,
]:
self.dtype_act = dtype_act
self.dtype_weights = dtype_weights
self.use_deepseek_fp8 = use_deepseek_fp8
self.top_k = top_k
flashinfer/fused_moe/core.py
Outdated
    self.num_experts,
    num_tokens,
)
print(f"instance_key: {instance_key}")
flashinfer/fused_moe/core.py
Outdated
tunning_config = (
    MoERunner.tuning_config_no_hidden_states_scales
    if hidden_states_scale is None
    else MoERunner.tuning_config_with_hidden_states_scales
)
cc @amirkl94
inline btg::Dtype get_dtype(int64_t const dtype) {
  switch (dtype) {
    case 0:
This currently finds the type by Dtype's Uid bits. Would it be more maintainable if we simply passed strings (so in code review the case will more obviously match the type name) and documented at the binding interface that the available options are those in DtypeDecl.h's Dtype?
The current approach has the advantage of being checked early via class DtypeTrtllmGen(IntEnum), so no objections, just raising options.
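For illustration, this is the kind of early check an IntEnum gives for free on the Python side (a hypothetical example, not code from this PR):

# Illustrative only: an unknown value is rejected at the Python boundary,
# before any integer ever reaches the C++ switch statement.
from enum import IntEnum

class DtypeSketch(IntEnum):
    Bfloat16 = 0
    E2m1 = 2
    E4m3 = 5

def to_backend_dtype(value: int) -> int:
    return int(DtypeSketch(value))  # raises ValueError if value is not a member

to_backend_dtype(2)     # fine
# to_backend_dtype(99)  # ValueError: 99 is not a valid DtypeSketch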
@aleozlx Thanks for the suggestion! What about defining a new enum in trtllm_fused_moe_kernel_launcher.cu and using macros to map it to btg::Dtype, then exposing it to Python?
Sure, that improves readability. Or we can use the class Dtype(IntEnum) I proposed in the other thread and remove another conversion. Hopefully __new__ serves to hide implementation details in a way that users won't rely on the value (compared to a separate function), although value changes can still cause breakage through the C++ interface, so we just have to align on something early.
flashinfer/fused_moe/core.py
Outdated
@@ -65,6 +66,43 @@ class RoutingMethodType(IntEnum):
    Unspecified = 5


# NOTE(siyuan): Need to keep this in sync with the counterpart defined in include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h
class DtypeTrtllmGen(IntEnum):
    Bfloat16 = (0,)
Or we can borrow the bit formation from TLLM_ENCODE_DTYPE, allowing the deletion of the get_dtype() conversion, e.g.:
from enum import IntEnum

class Dtype(IntEnum):
    def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid):
        value = (block_format_bit << 24) | (signed_bit << 20) | (integer_bit << 16) | (num_bits << 8) | uid
        obj = int.__new__(cls, value)
        obj._value_ = value
        return obj

    # keep the values in sync with include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h
    Bfloat16 = (0, 1, 0, 16, 0)
    Bool = (0, 0, 1, 1, 1)
    E2m1 = (1, 1, 0, 4, 2)
    E2m3 = (1, 1, 0, 6, 3)
    E3m2 = (1, 1, 0, 6, 4)
    E4m3 = (0, 1, 0, 8, 5)
    E5m2 = (0, 1, 0, 8, 6)
    Fp16 = (0, 1, 0, 16, 7)
    Fp32 = (0, 1, 0, 32, 8)
    Int8 = (0, 1, 1, 8, 9)
    Int32 = (0, 1, 1, 32, 10)
    Int64 = (0, 1, 1, 64, 11)
    MxE2m1 = (1, 1, 0, 4, 12)
    MxE4m3 = (1, 1, 0, 8, 13)
    UE8m0 = (0, 0, 0, 8, 14)
    UInt8 = (0, 0, 1, 8, 15)
    UInt16 = (0, 0, 1, 16, 16)
    UInt32 = (0, 0, 1, 32, 17)
    UInt64 = (0, 0, 1, 64, 18)
    UInt128 = (0, 0, 1, 128, 19)
    Void = (0, 1, 0, 0, 20)
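If the packed encoding above is adopted, the Python side could read fields straight out of the value; the helpers below are hypothetical and only show how the bit layout would be consumed.

# Hypothetical helpers for the packed layout above: bits 8-15 carry the element
# width and the low byte carries the Uid expected by the C++ interface.
def dtype_num_bits(d: Dtype) -> int:
    return (int(d) >> 8) & 0xFF

def dtype_uid(d: Dtype) -> int:
    return int(d) & 0xFF

assert dtype_num_bits(Dtype.E2m1) == 4
assert dtype_uid(Dtype.MxE4m3) == 13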
flashinfer/fused_moe/core.py
Outdated
@@ -1481,6 +1840,7 @@ def trtllm_fp4_block_scale_routed_moe(
    routing_method_type: int = 0,
    do_finalize: bool = True,
    enable_pdl: Optional[bool] = None,
    tune_max_num_tokens: int = 1024,
Shall we always append? That way, if users are passing the optional arguments positionally, they won't break. Or, in hindsight, I guess we should have put the optionals after a bare *, ... to prevent them from being used positionally at all, since positional usage is easier to break.
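For reference, a bare * in the signature is what forces the optionals to be keyword-only; the signature below is a hypothetical sketch, not the actual one in core.py.

# Hypothetical sketch: everything after the bare "*" must be passed by keyword,
# so appending new optional parameters later cannot break positional callers.
from typing import Optional

def trtllm_fp4_block_scale_routed_moe_sketch(
    routing_logits,
    hidden_states,
    *,
    routing_method_type: int = 0,
    do_finalize: bool = True,
    enable_pdl: Optional[bool] = None,
    tune_max_num_tokens: int = 1024,
):
    ...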
flashinfer/fused_moe/core.py
Outdated
    dtype = DtypeTrtllmGen.Bfloat16
elif x.dtype == torch.float8_e4m3fn:
    dtype = DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3
elif x.dtype == torch.uint8:
Just a minor note: we should also take care of torch.float4_e2m1x2
for torch 2.8+
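As a sketch of what that could look like, the branch below is guarded so older torch versions keep working; the E2m1/MxE2m1 choice mirrors the scale-based pattern in the quoted code and is an assumption, not confirmed behavior.

# Sketch only: extend the dtype mapping with the packed fp4 dtype added in
# torch >= 2.8. DtypeTrtllmGen is the enum added in flashinfer/fused_moe/core.py.
import torch

def infer_trtllm_dtype(x: torch.Tensor, scale):
    if x.dtype == torch.bfloat16:
        return DtypeTrtllmGen.Bfloat16
    if x.dtype == torch.float8_e4m3fn:
        return DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3
    if hasattr(torch, "float4_e2m1fn_x2") and x.dtype == torch.float4_e2m1fn_x2:
        return DtypeTrtllmGen.E2m1 if scale is None else DtypeTrtllmGen.MxE2m1
    raise ValueError(f"unsupported dtype: {x.dtype}")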
Thanks. But I didn't see float4_e2m1fn_x2 used anywhere in flashinfer. Will we add it all together in the future?
Not yet; we need to prepare for it as frameworks upgrade to torch 2.8. It could be done in later PRs.
There are some conflicts with the main branch after #1396 got merged; would you mind rebasing?
LGTM, next step is to get #1494 in, which is dependent on this one.
assert hidden_states.shape[0] == num_tokens, (
    "hidden_states's first dimension must be batch size."
)
assert hidden_states_scale is None or (
This assertion makes vLLM gpt-oss fail. In that case the scale is tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0', dtype=torch.float8_e4m3fn).
@weireweire I created this PR to fix this. Currently vLLM's flashinfer tag is 0.2.12. I can mark this PR as ready and increment the flashinfer tag to 0.2.13.
# TODO(siyuan): support fp8
moe_op.trtllm_fp4_block_scale_moe(
    routing_logits.to(torch.bfloat16),
This line actually breaks DeepSeek v3 routing. I have left suggested changes in #1494.
📌 Description
- Updates to AutoTuner, OptimizationProfile, and DynamicTensorSpec: DynamicTensorSpec can now take multiple input tensors.
- tensor_initializers in DynamicTensorSpec defines the initialization method for dynamic tensors. Before, they were all zero-initialized, which caused IMA in trtllm-gen's routing kernels (see the sketch after this list).
- New DtypeTrtllmGen enum in flashinfer/fused_moe/core.py.
- hidden_states_scales in trtllm-gen fp4 MoE no longer needs to be 1D.
TODOs
Performance
B200, clock speed locked at 1500 MHz, 1000 warmups, 1000 iterations, mxfp4 x mxfp8.
For nvfp4 x nvfp4 and mxfp4 x bf16, there is no significant perf gain.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).
Reviewer Notes