
Conversation

@hypdeb (Contributor) commented Oct 1, 2025

In low-latency contexts, it is not uncommon to encounter memory-bandwidth-bound GEMMs with a tiny leading dimension M. These cases are currently not handled as efficiently as they could be by library implementations. To fill this gap, I propose to expose generated GEMM kernels optimized for small batch sizes, which saturate memory bandwidth to a higher degree.

The main challenge in doing so is that these GEMMs expect the weight tensor (second operand) to be pre-processed into a layout better suited to saturating memory bandwidth. As such, it is not practical to expose them under the same API as the other GEMMs: they are not interchangeable without changing the caller's implementation. I have tentatively exposed them as "flavored" GEMMs, in contrast to the more "vanilla" GEMMs currently available.
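
To make the two-step contract concrete, here is a rough illustration of the kind of weight pre-processing involved. The tiling below is an assumption for illustration only; it is not the actual layout or helper used by these kernels.

```python
# Illustration only: NOT the layout used by this PR's kernels. The general
# idea is to reorder the row-major [n, k] weight so that each thread block
# streams one contiguous chunk instead of strided rows.
import torch

def shuffle_weights_for_tiles(b: torch.Tensor, tile_n: int = 128) -> torch.Tensor:
    n, k = b.shape
    assert n % tile_n == 0
    # [n, k] -> [n // tile_n, tile_n, k] -> [n // tile_n, k, tile_n], materialized
    # contiguously so one tile of B is a single linear span of memory.
    return b.view(n // tile_n, tile_n, k).transpose(1, 2).contiguous()
```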

Summary of the changes:

  • Added a C++ runner to be JIT-compiled for these new GEMMs: csrc/trtllm_flavored_gemm_runner.cu
  • Added a separate flashinfer/trtllm_flavored_gemm.py file containing the Python interface for the new GEMMs
  • Some stylistic refactoring of the autotuner, done while working out how it operates
  • Tests
  • Benchmarks
  • Some other minor cleanups along the way. Note: I will undo the extraction of fp8_utils.py, as the implementations of to_fp8 differ between the places it was extracted from

Next step:

I will add more kernels for larger batch sizes. This is required because the weight-matrix shuffling commits the user to this interface, so they also need efficient kernels for the larger batches they will encounter, for example during prefill when not running disaggregated serving.

Benchmarking results on GB200:

 m     n      k     TFLOPs/s   Time (ms)   TB/s
 1     2560   16384     9.65    0.008694   4.83
 1     2560   32768    11.34    0.014797   5.67
 1     5120   16384    15.10    0.011110   7.55
 1     5120   32768    12.21    0.027491   6.10
 1     8192   16384    11.75    0.022851   5.87
 1     8192   32768    13.06    0.041114   6.53
 2     2560   16384    18.38    0.009130   4.60
 2     2560   32768    21.21    0.015821   5.31
 2     5120   16384    30.21    0.011107   7.56
 2     5120   32768    24.41    0.027491   6.11
 2     8192   16384    23.43    0.022912   5.86
 2     8192   32768    26.15    0.041056   6.54
 4     2560   16384    36.22    0.009264   4.54
 4     2560   32768    43.55    0.015408   5.45
 4     5120   16384    60.40    0.011110   7.56
 4     5120   32768    48.82    0.027494   6.11
 4     8192   16384    46.71    0.022989   5.84
 4     8192   32768    52.10    0.041216   6.52
 8     2560   16384    72.47    0.009261   4.55
 8     2560   32768    84.84    0.015821   5.32
 8     5120   16384   120.84    0.011107   7.57
 8     5120   32768    97.37    0.027568   6.10
 8     8192   16384    93.41    0.022989   5.85
 8     8192   32768   104.21    0.041216   6.52
16     2560   16384   138.70    0.009677   4.37
16     2560   32768   174.22    0.015408   5.48
16     5120   16384   231.03    0.011619   7.26
16     5120   32768   190.13    0.028237   5.97
16     8192   16384   180.96    0.023734   5.68
16     8192   32768   205.52    0.041795   6.44
32     2560   16384   260.92    0.010288   4.14
32     2560   32768   322.64    0.016640   5.11
32     5120   16384   421.01    0.012752   6.65
32     5120   32768   371.18    0.028928   5.85
32     8192   16384   348.80    0.024627   5.49
32     8192   32768   400.89    0.042854   6.30
64     2560   16384   466.29    0.011514   3.76
64     2560   32768   458.96    0.023395   3.69
64     5120   16384   673.11    0.015952   5.37
64     5120   32768   679.79    0.031590   5.40
64     8192   16384   648.00    0.026512   5.14
64     8192   32768   766.41    0.044832   6.06
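
For reference, the TFLOPs/s and TB/s columns follow directly from the problem shape and the measured time. The sketch below reproduces them, assuming FP8 (1-byte) A and B operands and a 16-bit output; these dtype assumptions are mine, but they match the reported numbers.

```python
# Reproduce the reported figures from shape and runtime.
# Assumes FP8 (1 byte/element) inputs A [m, k] and B [n, k], 16-bit output C [m, n].
def throughput(m: int, n: int, k: int, time_ms: float) -> tuple[float, float]:
    flops = 2 * m * n * k                     # one multiply-add per (m, n, k) triple
    bytes_moved = m * k + n * k + 2 * m * n   # A + B in FP8, C in 16 bits
    t = time_ms * 1e-3
    return flops / t / 1e12, bytes_moved / t / 1e12  # (TFLOPs/s, TB/s)

print(throughput(1, 2560, 16384, 0.008694))   # ~(9.65, 4.83), matching the first row
```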

Summary of Changes

Hello @hypdeb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces highly optimized FP8 GEMM kernels tailored for low-latency, memory-bandwidth-bound operations, particularly beneficial for scenarios involving small batch sizes in large language models. The implementation includes a new C++ CUDA runner, a Python interface, and integrates these 'flavored' GEMMs with the existing autotuning framework. The changes aim to significantly improve performance by leveraging specialized weight tensor layouts and efficient kernel selection, with initial benchmarks demonstrating high throughput and memory bandwidth utilization.

Highlights

  • New FP8 GEMM Kernels: Introduced new generated GEMM kernels (trtllm-gen global scaled FP8 GEMMs) specifically optimized for low-latency, memory-bandwidth-bound scenarios with small leading dimension M. These kernels are designed to saturate memory bandwidth more effectively.
  • Flavored GEMM API: The new GEMMs are exposed through a dedicated 'flavored' API (trtllm_flavored_gemm) because they require specific pre-processing of the weight tensor (second operand) into a shuffled, block-layout format, making them distinct from existing 'vanilla' GEMMs.
  • Autotuner Enhancements: The autotuner received stylistic refactoring and functional improvements, including the introduction of an AutoTunerCacheHit dataclass for clearer cache handling and optimization of the profiling process to skip already cached configurations.
  • Centralized FP8 Utility: The to_float8 utility function, responsible for converting tensors to FP8, has been centralized into a new tests/utils_fp8.py file and imported where needed, reducing code duplication across benchmarks and tests (a minimal sketch of such a helper follows this list).
  • Benchmarking and Testing: Comprehensive benchmarks (benchmarks/bench_trtllm_gen_flavored_gemm.py) and tests (tests/test_gemm_fp8.py) have been added to validate the correctness and performance of the new flavored GEMMs across various matrix dimensions.
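
For reference, a minimal sketch of what such a per-tensor FP8 conversion helper typically looks like; this is an illustrative assumption, not the exact code moved in this PR.

```python
import torch

def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    """Per-tensor quantization to FP8, returning the tensor and its inverse scale."""
    finfo = torch.finfo(dtype)
    # Scale so the largest magnitude maps onto the FP8 representable range.
    amax = x.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax.float()
    x_fp8 = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.reciprocal()
```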

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new "flavored" GEMM implementation optimized for small batch sizes using FP8, which is a valuable performance enhancement. The changes include a new C++ CUDA runner, a Python interface with autotuning support, and corresponding benchmarks and tests.

The code is well-structured, and the refactoring of the autotuner and utility functions improves code quality. I have a few suggestions to enhance maintainability and clarity:

  • Address a FIXME comment and code duplication in the C++ runner.
  • Clarify the data flow of the workspace_buffer in the Python interface.
  • Improve docstrings for better user understanding.
  • Avoid shadowing Python built-ins in variable names (a brief illustration follows below).

Overall, this is a solid contribution. Addressing these points will make the code even better.
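
A brief illustration of the last suggestion (the function names here are made up, not taken from the PR):

```python
# Shadowing a built-in makes Python's `input` unreachable inside the function.
def gemm_shadowing(input, weight):        # `input` shadows the built-in
    return input @ weight

# Preferred: descriptive parameter names that don't collide with built-ins.
def gemm_no_shadowing(a, b_shuffled):
    return a @ b_shuffled
```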

@hypdeb force-pushed the gemm_tuning branch 2 times, most recently from 3e3dc6c to 25abf20 on October 4, 2025 at 10:46

@yzh119 (Collaborator) left a comment

Another request: can we add gen_trtllm_low_latency_gemm_module to https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/aot.py, so that the pre-built JIT cache will include this module?

@hypdeb force-pushed the gemm_tuning branch 2 times, most recently from 35f268f to a4e8f34 on October 7, 2025 at 08:04

@yzh119 (Collaborator) left a comment

Overall LGTM, ping @aleozlx for another review on hardware compatibility.

@yzh119 (Collaborator) commented Oct 7, 2025

/bot run

@flashinfer-bot

GitLab MR !68 has been created, and the CI pipeline #36194695 is currently running. I'll report back once the pipeline job completes.

@hypdeb (Author) commented Oct 8, 2025

/bot run

@flashinfer-bot

@hypdeb is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
