feat: Add FP4 TRTLLM-Gen throughput MOE batched gemms #1882
base: main
Conversation
Summary of Changes
Hello @jiahanc, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly upgrades the FlashInfer library's capabilities for high-performance inference, specifically targeting FP4 Mixture-of-Experts models. It refines the underlying batched GEMM and fused MoE kernels by introducing more precise memory management, flexible CTA scheduling, and advanced configuration options. These changes are designed to boost throughput and optimize resource utilization for complex AI workloads, ensuring better performance and adaptability across different hardware configurations.
Code Review
This pull request adds support for FP4 throughput-oriented batched GEMMs for Mixture-of-Experts (MoE) in TRTLLM-Gen. The changes are extensive, touching CUDA kernels, C++ headers, and Python bindings. My review focuses on improving code maintainability by addressing code duplication and redundant checks. I've identified a few areas where refactoring could make the code cleaner and easier to manage.
int32_t max_num_padded_tokens_gemm1 =
    tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
        max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
int32_t max_num_padded_tokens_gemm2 =
    tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
        max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));
This block for calculating max_num_padded_tokens_gemm1 and max_num_padded_tokens_gemm2 is duplicated in trtllm_fp4_block_scale_moe_launcher (lines 774-779). In fact, the entire function trtllm_fp8_block_scale_moe_launcher is very similar to trtllm_fp4_block_scale_moe_launcher. To improve maintainability and reduce redundancy, consider refactoring the common logic into a templated helper function.
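For instance, a minimal sketch of such a helper (the name computeMaxNumPaddedTokensPerGemm and the templated argument type are hypothetical, not part of this PR):

#include <cstdint>
#include <utility>

// Hypothetical shared helper (sketch only): computes the padded token counts for
// both GEMMs in one place so the FP8 and FP4 launchers do not repeat these calls.
// Assumes the same headers/namespaces the existing launchers already include.
template <typename MoeArgs>
std::pair<int32_t, int32_t> computeMaxNumPaddedTokensPerGemm(int32_t max_num_padded_tokens,
                                                             MoeArgs const& args) {
  int32_t gemm1 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
      max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
  int32_t gemm2 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
      max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));
  return {gemm1, gemm2};
}

Both launchers could then obtain both values with a single call, e.g. auto [max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2] = computeMaxNumPaddedTokensPerGemm(max_num_padded_tokens, args);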
  tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
- # Cap to 8-64 tokens per CTA tile
- # as it's the range supported by the kernel.
- tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+ if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
+     tile_tokens_dim = 192
+ # Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
  return tile_tokens_dim
The logic for calculating tile_tokens_dim from num_tokens_per_expert is duplicated here and in flashinfer.utils.calculate_tile_tokens_dim. To avoid code duplication and improve maintainability, you could extract this common logic into a new helper function in flashinfer.utils.
For example:
# in flashinfer/utils.py
def _calculate_tile_dim_from_tokens_per_expert(num_tokens_per_expert: int, max_tile_tokens_dim: int = 128) -> int:
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
if 128 < num_tokens_per_expert < 256:
tile_tokens_dim = 192
tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
return tile_tokens_dim
Then both calculate_tile_tokens_dim and get_tile_tokens_dim can call this helper function after calculating their respective num_tokens_per_expert.
if (options.mUseDeepSeekFp8) {
  TLLM_CHECK_ERROR(options.mClusterDimX == 1, "2CTA Gemm is not supported for DeepSeekFp8");
}
if (options.mUseDeepSeekFp8) {
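The review summary above calls out redundant checks: the two adjacent conditions on options.mUseDeepSeekFp8 in this hunk could be folded into a single block. A sketch of the consolidated form (the body of the second block is not shown in the excerpt and is left elided):

// Sketch only: merge the duplicated condition; the checks from the second
// block would move inside the same braces.
if (options.mUseDeepSeekFp8) {
  TLLM_CHECK_ERROR(options.mClusterDimX == 1, "2CTA Gemm is not supported for DeepSeekFp8");
  // ... remaining DeepSeekFp8-specific checks from the second block ...
}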
Commits: "update some work", "update some more files" (Signed-off-by: jiahanc <[email protected]>)
jiahanc force-pushed from 33fad7f to 36a6f52 (Compare)
/bot run
Commit: "remove comment" (Signed-off-by: jiahanc <[email protected]>)
jiahanc force-pushed from c187427 to 639df1e (Compare)
/bot --help
@jiahanc is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
/bot run
LGTM
Waiting for the B300 hanging issue to be fixed.
[FAILED] Pipeline #36221439: 1/17 passed, 16 failed (unit_test_5090_dlcluster: [cu130], unit_test_5090_dlcluster: [cu129], unit_test_h100_dlcluster: [cu130], +13 more)
lgtm too
📌 Description
Add SM100 TRTLLM-Gen FP4 and MXFP4 batched GEMMs for larger concurrencies.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).
Reviewer Notes