
Conversation


@jiahanc jiahanc commented Oct 7, 2025

📌 Description

Add SM100 TRTLLM-Gen FP4 and MXFP4 batched GEMMs for larger concurrencies.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes


Summary of Changes

Hello @jiahanc, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly upgrades the FlashInfer library's capabilities for high-performance inference, specifically targeting FP4 Mixture-of-Experts models. It refines the underlying batched GEMM and fused MoE kernels by introducing more precise memory management, flexible CTA scheduling, and advanced configuration options. These changes are designed to boost throughput and optimize resource utilization for complex AI workloads, ensuring better performance and adaptability across different hardware configurations.

Highlights

  • FP4 MoE Batched GEMM Support: This pull request introduces support for FP4 (4-bit floating point) Mixture-of-Experts (MoE) batched General Matrix Multiplications (GEMMs) within the TRTLLMGen framework, aiming to improve throughput.
  • Enhanced MoE Kernel Logic: The fused MoE kernel launchers have been updated to handle different padding requirements for intermediate GEMM outputs (GEMM1 and GEMM2) and to more accurately estimate the maximum number of CTAs (Cooperative Thread Arrays) for efficient execution.
  • Flexible Tile Token Dimension Calculation: The logic for calculating tile_tokens_dim in MoE operations has been made more flexible, allowing for larger tile sizes (up to 128 or 192) and better adaptation to varying token counts, which can improve performance.
  • Advanced Batched GEMM Options: New options have been added to the batched GEMM configuration, including mRouteSfsImpl for granular control over scaling factor routing and CtaSwizzleType for CTA swizzling patterns. Validation rules for 2CTA GEMM configurations and DeepSeek FP8 compatibility have also been refined.
  • TMA OOB Optimization: Improvements to TMA (Tensor Memory Access) Out-Of-Bounds (OOB) optimization have been implemented, utilizing new constants (TmaDimMax, LargeN, XLargeN) to enhance memory access patterns and efficiency.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This pull request adds support for FP4 throughput-oriented batched GEMMs for Mixture-of-Experts (MoE) in TRTLLM-Gen. The changes are extensive, touching CUDA kernels, C++ headers, and Python bindings. My review focuses on improving code maintainability by addressing code duplication and redundant checks. I've identified a few areas where refactoring could make the code cleaner and easier to manage.

Comment on lines +396 to +401
int32_t max_num_padded_tokens_gemm1 =
    tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
        max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
int32_t max_num_padded_tokens_gemm2 =
    tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
        max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));

Severity: medium

This block for calculating max_num_padded_tokens_gemm1 and max_num_padded_tokens_gemm2 is duplicated in trtllm_fp4_block_scale_moe_launcher (lines 774-779). In fact, the entire function trtllm_fp8_block_scale_moe_launcher is very similar to trtllm_fp4_block_scale_moe_launcher. To improve maintainability and reduce redundancy, consider refactoring the common logic into a templated helper function.
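As a rough illustration, a shared helper along these lines could replace the duplicated block in both launchers (the function name, the free-function form, and passing the dtype bit widths as parameters are assumptions for this sketch, not code from the PR):

```cpp
#include <cstdint>
#include <utility>

// Sketch only: RoutingT stands in for tensorrt_llm::kernels::trtllmgen_moe::Routing.
// The dtype bit widths are passed in so the FP8 and FP4 launchers can share one
// implementation despite using different element/output dtypes.
template <typename RoutingT>
std::pair<int32_t, int32_t> maxPaddedTokensForGemms(int32_t max_num_padded_tokens,
                                                    int32_t intermediate_size,
                                                    int32_t hidden_size, int32_t elt_bits,
                                                    int32_t out_bits) {
  // GEMM1 output is padded against the intermediate dimension and the element dtype width.
  int32_t gemm1 = RoutingT::maybeGetMinTokenCount(max_num_padded_tokens, intermediate_size, elt_bits);
  // GEMM2 output is padded against the hidden dimension and the output dtype width.
  int32_t gemm2 = RoutingT::maybeGetMinTokenCount(max_num_padded_tokens, hidden_size, out_bits);
  return {gemm1, gemm2};
}
```

Both trtllm_fp8_block_scale_moe_launcher and trtllm_fp4_block_scale_moe_launcher could then obtain the pair with a single call.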

Comment on lines 1117 to 919
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
    tile_tokens_dim = 192
# Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
return tile_tokens_dim

Severity: medium

The logic for calculating tile_tokens_dim from num_tokens_per_expert is duplicated here and in flashinfer.utils.calculate_tile_tokens_dim. To avoid code duplication and improve maintainability, you could extract this common logic into a new helper function in flashinfer.utils.

For example:

# in flashinfer/utils.py
def _calculate_tile_dim_from_tokens_per_expert(num_tokens_per_expert: int, max_tile_tokens_dim: int = 128) -> int:
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    if 128 < num_tokens_per_expert < 256:
        tile_tokens_dim = 192
    tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
    return tile_tokens_dim

Then both calculate_tile_tokens_dim and get_tile_tokens_dim can call this helper function after calculating their respective num_tokens_per_expert.

Comment on lines +1056 to 1059
if (options.mUseDeepSeekFp8) {
  TLLM_CHECK_ERROR(options.mClusterDimX == 1, "2CTA Gemm is not supported for DeepSeekFp8");
}
if (options.mUseDeepSeekFp8) {

Severity: medium

These two consecutive if (options.mUseDeepSeekFp8) blocks can be combined into a single block to improve readability and avoid redundant checks.
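For instance, a combined block might read as follows (sketch only; the body of the second original block is elided since it is not shown above):

```cpp
if (options.mUseDeepSeekFp8) {
  TLLM_CHECK_ERROR(options.mClusterDimX == 1, "2CTA Gemm is not supported for DeepSeekFp8");
  // ... remaining DeepSeek FP8 validation from the second block ...
}
```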

Signed-off-by: jiahanc <[email protected]>

update some work

Signed-off-by: jiahanc <[email protected]>

update some work

Signed-off-by: jiahanc <[email protected]>

update some more files

Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
@jiahanc force-pushed the trtllmgen_throughput branch from 33fad7f to 36a6f52 on October 7, 2025 21:31
Signed-off-by: jiahanc <[email protected]>

yzh119 commented Oct 7, 2025

/bot run

@flashinfer-bot

GitLab MR !70 has been created, and the CI pipeline #36210663 is currently running. I'll report back once the pipeline job completes.

@jiahanc changed the title from "feat: Add FP4 TRTLLMGen throughput MOE batched gemms" to "feat: Add FP4 TRTLLM-Gen throughput MOE batched gemms" on Oct 7, 2025
Signed-off-by: jiahanc <[email protected]>

remove comment

Signed-off-by: jiahanc <[email protected]>
@jiahanc force-pushed the trtllmgen_throughput branch from c187427 to 639df1e on October 8, 2025 03:09

jiahanc commented Oct 8, 2025

/bot --help

@flashinfer-bot

@jiahanc is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww


yzh119 commented Oct 8, 2025

/bot run

@flashinfer-bot

GitLab MR !70 has been updated with latest changes, and the CI pipeline #36221439 is currently running. I'll report back once the pipeline job completes.


@yzh119 left a comment


LGTM


@yzh119 left a comment


Waiting for the B300 hanging issue to be fixed

@flashinfer-bot

[FAILED] Pipeline #36221439: 1/17 passed, 16 failed (unit_test_5090_dlcluster: [cu130], unit_test_5090_dlcluster: [cu129], unit_test_h100_dlcluster: [cu130], +13 more)


@aleozlx left a comment


lgtm too
