Skip to content

Conversation

djmmoss
Copy link
Contributor

@djmmoss djmmoss commented Sep 30, 2025

πŸ“Œ Description

Adds FP8 Block Scaling Fused Cutlass MoE for SM90.

βœ… Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

πŸ§ͺ Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Copy link
Contributor

Summary of Changes

Hello @djmmoss, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Fused Cutlass Mixture-of-Experts (MoE) functionality by introducing FP8 Block Scaling capabilities, primarily for SM90 architectures. It integrates a sophisticated DeepGEMM JIT compilation system that dynamically generates and optimizes CUDA kernels for efficient FP8 GEMM operations. This allows for more flexible and performant FP8 computations in MoE layers, moving from a stub implementation to a full-fledged feature with comprehensive kernel support and updated build configurations.

Highlights

  • FP8 Block Scaling for MoE: Introduced support for FP8 Block Scaling in Fused Cutlass Mixture-of-Experts (MoE) operations, specifically targeting SM90 (Hopper) architectures.
  • DeepGEMM JIT Integration: Integrated a DeepGEMM JIT compilation framework, enabling dynamic generation and optimization of CUDA kernels for FP8 GEMM operations.
  • New CUDA Kernels: Added several new CUDA kernel files (fp8_gemm_impl.cuh, sm89_fp8_gemm_1d1d.cuh, fp8_blockscale_gemm_kernel.cuh, etc.) to handle the specialized FP8 GEMM computations and block-wise scaling.
  • Build System Updates: Modified the build process in flashinfer/fused_moe/core.py to include the new FP8 blockscale GEMM implementation and added -lnvrtc to linker flags for NVRTC support.
  • Test Case Enhancements: Updated existing test cases in tests/test_trtllm_cutlass_fused_moe.py to validate the new FP8 block scaling functionality, including changes to scale initialization and output handling.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with πŸ‘ and πŸ‘Ž on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for FP8 block scaling in fused Mixture-of-Experts (MoE) kernels for SM90 architectures, which is a significant feature enhancement. The implementation includes a new JIT compilation framework for CUDA kernels, the FP8 GEMM kernel implementations, and updates to the build system and tests. The code is generally well-structured, but I've identified a few issues that should be addressed to improve correctness and maintainability, including a bug in kernel name generation, an inconsistent function declaration, and a redundant loop.

Comment on lines +314 to +315
std::to_string(num_groups) + "_" + std::to_string(num_stages) +
std::to_string(num_groups) + "_" + std::to_string(num_stages) + "_" +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a duplication in the construction of the kernel name. The num_groups and num_stages parameters are appended twice. This appears to be a copy-paste error and will result in an incorrect and unnecessarily long kernel name, which could lead to caching issues.

                       std::to_string(num_groups) + "_" + std::to_string(num_stages) + "_" +


namespace deep_gemm {
template <typename T>
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function make_2d_tma_a_desc is declared static, while other similar helper functions in this file (e.g., make_2d_tma_b_desc) are not. This is inconsistent. For function templates in header files, the static keyword is generally not necessary and can be removed for consistency and adherence to modern C++ practices.

CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k,

Comment on lines +117 to +120
for (int i = 0; i < num_problems; i++) {
fp8_mat_b = reinterpret_cast<__nv_fp8_e4m3*>(const_cast<void*>(mat_b));
per_block_scales = const_cast<float*>(scales_b);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This loop is redundant. It repeatedly assigns the same pointer values to fp8_mat_b and per_block_scales in each iteration. The loop can be removed and replaced with a single assignment to improve clarity and remove dead code.

    fp8_mat_b = reinterpret_cast<__nv_fp8_e4m3*>(const_cast<void*>(mat_b));
    per_block_scales = const_cast<float*>(scales_b);

Signed-off-by: Duncan Moss <[email protected]>
Signed-off-by: Duncan Moss <[email protected]>
"-O3",
"-std=c++17",
"-Wno-switch-bool",
"-D__CUDACC_VER_MAJOR__=" + str(torch.version.cuda.split(".")[0]),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? They are nvcc native macros.

Copy link
Contributor Author

@djmmoss djmmoss Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I was compiling, they weren't enabled in the JIT mode, I'm not sure it this was an environment issue, but I just installed the package with pip install -e . in a fresh venv. I believe it's because this is c++ instead of nvcc?

E           [2/3] c++ -MMD -MF fp4_quantization_90/fp4Op.o.d -DTORCH_EXTENSION_NAME=fp4_quantization_90 -DTORCH_API_INCLUDE_EXTENSION_H -DPy_LIMITED_API=0x03090000 -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -D_GLIBCXX_USE_CXX11_ABI=1 -I/home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal -I/home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal/include -isystem /usr/include/python3.10 -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/venv/lib/python3.10/site-packages/torch/include -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/venv/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest_r12.9/include -isystem /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest_r12.9/include/cccl -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/include -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/3rdparty/cutlass/include -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/3rdparty/cutlass/tools/util/include -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/3rdparty/spdlog/include -fPIC -O3 -std=c++17 -Wno-switch-bool -DENABLE_BF16 -DENABLE_FP8 -DENABLE_FP4 -c /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp -o fp4_quantization_90/fp4Op.o 
E           FAILED: [code=1] fp4_quantization_90/fp4Op.o 
E           c++ -MMD -MF fp4_quantization_90/fp4Op.o.d -DTORCH_EXTENSION_NAME=fp4_quantization_90 -DTORCH_API_INCLUDE_EXTENSION_H -DPy_LIMITED_API=0x03090000 -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -D_GLIBCXX_USE_CXX11_ABI=1 -I/home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal -I/home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal/include -isystem /usr/include/python3.10 -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/venv/lib/python3.10/site-packages/torch/include -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/venv/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest_r12.9/include -isystem /home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest_r12.9/include/cccl -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/include -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/3rdparty/cutlass/include -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/3rdparty/cutlass/tools/util/include -isystem /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/3rdparty/spdlog/include -fPIC -O3 -std=c++17 -Wno-switch-bool -DENABLE_BF16 -DENABLE_FP8 -DENABLE_FP4 -c /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp -o fp4_quantization_90/fp4Op.o 
E           /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp: In function β€˜at::Tensor torch_ext::mxfp4_dequantize_host(at::Tensor, at::Tensor, int64_t)’:
E           /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp:405:44: error: β€˜__CUDACC_VER_MAJOR__’ was not declared in this scope
E             405 |   std::cout << "Current CUDA version: " << __CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 << std::endl;
E                 |                                            ^~~~~~~~~~~~~~~~~~~~
E           /home/scratch.dmoss_gpu_1/repos/squirtle/flashinfer/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp:405:75: error: β€˜__CUDACC_VER_MINOR__’ was not declared in this scope
E             405 |   std::cout << "Current CUDA version: " << __CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 << std::endl;
E                 |                                                                           ^~~~~~~~~~~~~~~~~~~~
E           ninja: build stopped: subcommand failed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes c++ files will be compiled with gcc instead of nvcc.
Considering these file includes cuda headers, can we rename it to fp4Op.cu instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, this was actually fixed in different PR. I've removed this change

@djmmoss djmmoss changed the title feat:enable fp8 blockscale moe for fused cultass feat:enable fp8 blockscale moe for fused cultass for sm90 Sep 30, 2025
@djmmoss
Copy link
Contributor Author

djmmoss commented Oct 1, 2025

@yzh119 ready for final review

@@ -0,0 +1,543 @@
/*
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are these implementations from https://github.com/deepseek-ai/DeepGEMM

We already have deepgemm kernels

If there are not significant changes, we should unify them (maybe not in this PR).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tensorrt_llm/deep_gemm files are generated here: https://github.com/NVIDIA/TensorRT-LLM/tree/main/cpp/tensorrt_llm/deep_gemm

I checked the deepgemm kernels that have already been integrated into flashinfer, currently they only seem to work on blackwell (this PR is for hopper). I can look in a follow-up PR into merging the two.

@yzh119
Copy link
Collaborator

yzh119 commented Oct 7, 2025

/run bot

@yongwww
Copy link
Collaborator

yongwww commented Oct 7, 2025

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !69 has been created, and the CI pipeline #36196253 is currently running. I'll report back once the pipeline job completes.

@yongwww
Copy link
Collaborator

yongwww commented Oct 7, 2025

/run bot

it is bot run @yzh119 :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants