
[feat] Add optimized clusters TopK kernel and DeepSeekV3.2 sparse indexer fused kernel#2814

Draft
Aalanli wants to merge 4 commits into flashinfer-ai:main from Aalanli:alin_topk

Conversation


@Aalanli Aalanli commented Mar 18, 2026

This PR specifically targets the low-latency decode stage, where SMs are abundant relative to the problem size.

The key optimization for top-K is the use of CTA clusters on top of a fused radix algorithm. For the sparse indexer, a further optimization fuses the computation of the first histogram bin into the epilogue of the GEMM kernel. This eliminates one read pass over the logits, so each row of logits is now read only once in total; the rest is cached in shared memory for later stages. The fusion does not affect GEMM performance, as the histogram work is moved off the critical path of the epilogue, while it speeds up the subsequent top-K.
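For intuition, the histogram-driven selection can be sketched in NumPy (an illustrative single-level model, not the CUDA kernel; the real kernel runs multiple 8-bit radix passes over cached candidates, per row, across a CTA cluster):

```python
import numpy as np

def float_to_ordered_u32(x: np.ndarray) -> np.ndarray:
    """Map float32 values to uint32 keys whose unsigned order matches float order."""
    b = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return np.where(b >> 31, ~b, b | np.uint32(0x80000000)).astype(np.uint32)

def radix_topk_indices(logits: np.ndarray, k: int) -> np.ndarray:
    keys = float_to_ordered_u32(logits)
    # One histogram pass over the most-significant byte (256 bins).
    hist = np.bincount(keys >> 24, minlength=256)
    # above[b] = number of keys whose top byte is >= b (non-increasing in b).
    above = np.cumsum(hist[::-1])[::-1]
    # Threshold bin: the largest bin that still covers at least k keys.
    t = int(np.nonzero(above >= k)[0].max())
    top_byte = keys >> 24
    sure = np.nonzero(top_byte > t)[0]   # definitely in the top-k
    ties = np.nonzero(top_byte == t)[0]  # threshold bin: needs refinement
    # A real radix kernel recurses on lower bytes of the cached candidates;
    # this sketch just sorts the (small) tie set to fill the remaining slots.
    ties = ties[np.argsort(keys[ties])[::-1]]
    return np.concatenate([sure, ties[: k - len(sure)]])
```

In this picture, the epilogue fusion corresponds to producing `hist` for the first byte while the logits are still on-chip, so the top-K stage can start from `above` without re-reading the logits row.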

Comparison against FlashInfer's top-K

Note there is a minor difference between the two ops: FlashInfer's existing top_k also writes the logit values in addition to the indices, but that cost should be negligible.

> python benchmarks/bench_topk.py --op dsv3_topk

====================================================================================================
dsv3_topk: DeepSeek v3 sparse-attention top-K shapes (k=2048, float32)
  fast_topk_clusters: histogram-driven radix top-K (variable-length aware)
  flashinfer top_k:   generic radix top-K (padded input)
====================================================================================================
 batch    seq_len |   fast_topk_clusters   flashinfer_top_k    Speedup
--------------------------------------------------------------------------------
     1       1024 |               1.79us             7.39us      4.13x
     1       4096 |              10.59us            12.34us      1.16x
     1       8192 |              10.85us            14.93us      1.38x
     1      32768 |              11.87us            27.33us      2.30x
     1      40960 |              12.90us            37.34us      2.90x
     2       1024 |               1.82us             7.87us      4.32x
     2       4096 |              10.66us            11.15us      1.05x
     2       8192 |              10.72us            14.90us      1.39x
     2      32768 |              11.81us            25.10us      2.13x
     2      40960 |              12.77us            37.89us      2.97x
     4       1024 |               1.82us             8.00us      4.39x
     4       4096 |              10.78us            13.23us      1.23x
     4       8192 |              10.78us            15.33us      1.42x
     4      32768 |              11.84us            29.60us      2.50x
     4      40960 |              12.86us            29.54us      2.30x
    32       1024 |               1.95us             8.16us      4.18x
    32       4096 |              13.31us            13.47us      1.01x
    32       8192 |              13.57us            15.49us      1.14x
    32      32768 |              15.46us            29.86us      1.93x
    32      40960 |              16.99us            29.82us      1.76x
    64       1024 |               2.75us             8.38us      3.05x
    64       4096 |              24.48us            13.50us      0.55x
    64       8192 |              25.44us            15.78us      0.62x
    64      32768 |              29.12us            29.90us      1.03x
    64      40960 |              31.97us            30.21us      0.94x

For high batch sizes the algorithm would need to be adjusted, as increasing row parallelism via clusters no longer makes sense there. Additionally, the SM-scheduling portion of the GEMM kernel would need adjustment: at low batch sizes every SM processes one q block, but at higher batch sizes it makes sense to use a persistent schedule that lets every SM process multiple q blocks, which may in turn introduce overhead at low batch sizes.
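The two scheduling regimes discussed here can be modeled abstractly (a hypothetical Python sketch; the PR's actual sm_map layout is produced by the metadata kernel and may differ):

```python
def one_block_per_sm(batch: int, tiles_per_row: int):
    """Low-batch schedule: each (row, tile) work item gets its own SM/block."""
    return [[(b, t)] for b in range(batch) for t in range(tiles_per_row)]

def persistent_schedule(num_sms: int, batch: int, tiles_per_row: int):
    """High-batch schedule: each SM loops over a strided slice of the work,
    so it may process several q blocks; that loop is the overhead a
    one-block-per-SM launch avoids at low batch sizes."""
    work = [(b, t) for b in range(batch) for t in range(tiles_per_row)]
    return [work[sm::num_sms] for sm in range(num_sms)]
```

For example, batch 64 with 4 tiles per row on 132 SMs gives most SMs two work items under the persistent schedule, whereas the one-block-per-SM variant would need 256 concurrent blocks.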

Comparison against deep_gemm + FlashInfer top-K

This benchmarks the entire indexer, where deepgemm_ref uses deep_gemm's fp8_paged_mqa_logits together with FlashInfer's top_k.

> python benchmarks/bench_dsv3_sparse_indexer.py --pdl

====================================================================================================
dsv3_sparse_indexer: DeepSeek v3 sparse-attention end-to-end (k=2048, fp8) PDL=on
  mqa_fused:     FlashInfer fused logits + histogram top-K
  mqa_non_fused: FlashInfer separate logits + radix top-K
  deepgemm_ref:  deep_gemm fp8_paged_mqa_logits + flashinfer top_k
====================================================================================================
 batch    seq_len |    mqa_fused   mqa_non_fused   deepgemm_ref    speedup
----------------------------------------------------------------------------------
     1       1024 |       6.82us          6.46us        11.14us      1.63x
     1       4096 |      15.87us         16.16us        16.22us      1.02x
     1       8192 |      15.78us         16.10us        18.88us      1.20x
     1      32768 |      17.20us         17.73us        29.54us      1.72x
     1      40960 |      18.66us         18.82us        39.49us      2.12x
     2       1024 |       7.71us          6.66us        11.55us      1.50x
     2       4096 |      15.65us         16.10us        16.13us      1.03x
     2       8192 |      16.26us         16.45us        19.52us      1.20x
     2      32768 |      18.14us         18.88us        30.08us      1.66x
     2      40960 |      19.65us         20.67us        41.66us      2.12x
     4       1024 |       8.13us          7.07us        11.60us      1.43x
     4       4096 |      15.71us         16.13us        17.46us      1.11x
     4       8192 |      16.90us         17.15us        19.82us      1.17x
     4      32768 |      20.67us         20.96us        31.71us      1.53x
     4      40960 |      21.12us         22.02us        37.25us      1.76x
     8       1024 |       8.22us          7.10us        12.22us      1.49x
     8       4096 |      16.77us         17.12us        18.11us      1.08x
     8       8192 |      17.50us         17.86us        20.99us      1.20x
     8      32768 |      24.03us         24.66us        35.26us      1.47x
     8      40960 |      25.73us         26.54us        40.73us      1.58x
    32       1024 |       8.77us          8.19us        13.06us      1.49x
    32       4096 |      21.86us         22.40us        21.15us      0.97x
    32       8192 |      25.14us         25.76us        25.74us      1.02x
    32      32768 |      46.99us         47.86us        58.98us      1.26x
    32      40960 |      54.70us         56.10us        66.43us      1.21x
    64       1024 |      10.82us          9.82us        14.24us      1.32x
    64       4096 |      34.85us         36.82us        24.56us      0.70x
    64       8192 |      42.69us         44.96us        34.03us      0.80x
    64      32768 |      85.10us         87.90us        84.48us      0.99x
    64      40960 |      98.69us        101.58us        97.10us      0.98x

Comparison against vLLM's indexer

image

Individual operator comparison

Epilogue fusion does not affect GEMM performance

image

Top-K gets a speedup

image

Summary by CodeRabbit

  • New Features

    • Added histogram-based MQA top-K indexer with fused and non-fused execution paths, SM mapping, and PD-enabled launch support.
    • Exposed high-performance top-K cluster kernels and Python bindings for runtime use.
    • Added JIT module registration and package exports for the new indexer.
    • Added end-to-end benchmarking scripts for sparse indexer and top-K paths.
  • Tests

    • Added comprehensive tests validating logits and top-K outputs against reference implementations.


coderabbitai bot commented Mar 18, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

📝 Walkthrough


Adds a new MQA histogram-based DeepSeek v3 sparse-attention indexer: CUDA kernels (logits, fused epilogue, histogram-based top-K), TVM-FFI bindings, Python JIT wrappers and API, benchmarks, and tests; supports PDL-enabled launches and optional deep_gemm reference timing.

Changes

Cohort / File(s) Summary
Benchmarks
benchmarks/bench_dsv3_sparse_indexer.py, benchmarks/bench_topk.py
New end-to-end DSV3 sparse-indexer benchmark and added dsv3_topk benchmarking path (flashinfer/sglang comparisons, deep_gemm optional). CLI options for batch/seq sizes, PDL, CUPTI, CUDA graphs; reports per-path timings and speedups.
Python API & JIT
flashinfer/__init__.py, flashinfer/aot.py, flashinfer/jit/mqa_histogram.py, flashinfer/mqa_histogram.py
New JIT spec and module loader for MQA histogram kernels; exposes get_mqa_metadata, mqa_topk_indexer, and mqa_topk_indexer_non_fused; registers real and fake ops and auto-detects SM mapping.
CUDA Host Bindings
csrc/flashinfer_mqa_histogram_binding.cu
TVM-FFI entry points added: mqa_topk_indexer, get_mqa_metadata, fast_topk_clusters_fused, fast_topk_clusters, mqa_logits, mqa_logits_fused with input validation and stream handling.
CUDA Kernels — MQA Logits & Epilogue
csrc/mqa_v2.cu, csrc/mqa_v2_hist.cu, csrc/mqa_metadata.cu
New templated kernels and launchers for MQA logits, fused v3 epilogue (logits + histogram), and metadata SM mapping; per-SM scheduling, tcgen05 MMA usage, PDL-aware launch paths.
CUDA Kernels — TopK Clustering
csrc/fast_topk_clusters.cu, src/csrc/topk_clusters_pre_hist.cu
New fast_topk_clusters implementations including vectorized processing, per-cluster histogram pruning, caching, multi-phase reductions, and a fused prologue launcher (PDL-aware).
CUDA Headers / Utilities
include/flashinfer/mqa_histogram/common.cuh, include/flashinfer/mqa_histogram/tcgen05_utils.cuh
Low-level CUDA utilities: warp/CTA helpers, mbarrier primitives, kernel-smem setup helpers, tcgen05-specific load/MMA helpers, and conversion utilities used by kernels.
Tests
tests/attention/test_mqa_histogram.py
End-to-end tests with synthetic FP8 KV cache generation, pure-PyTorch reference dequant/top-K, and validation for fused and non-fused indexers (PDL flag coverage).

Sequence Diagram(s)

sequenceDiagram
    participant User as Python Caller
    participant PyAPI as flashinfer.mqa_histogram
    participant Runtime as JIT Module
    participant Device as CUDA Device
    participant Meta as Metadata Kernel
    participant Epilogue as MQA Fused Epilogue
    participant TopK as FastTopK Clustering

    User->>PyAPI: call mqa_topk_indexer(q, k_cache, weights, seq_lens,...)
    PyAPI->>Runtime: ensure module built & buffers allocated (logits, histogram, indices)
    PyAPI->>Meta: launch get_mqa_metadata(seq_lens) -> sm_map
    PyAPI->>Epilogue: launch_mqa_v3_fused_epilogue(q, k_cache, weights, ..., sm_map)
    Epilogue->>Device: tcgen05 MMA loads, compute logits + update histogram
    Epilogue-->>PyAPI: logits + histogram ready on device
    PyAPI->>TopK: fast_topk_clusters_fused(logits, histogram, seq_lens)
    TopK->>Device: histogram-based pruning, caching, emit top-K indices
    TopK-->>PyAPI: return indices (and optionally logits)
    PyAPI-->>User: (logits, indices)
sequenceDiagram
    participant User as Python Caller
    participant PyAPI as flashinfer.mqa_histogram
    participant Runtime as JIT Module
    participant Device as CUDA Device
    participant Meta as Metadata Kernel
    participant Logits as MQA Logits Kernel
    participant TopK as FastTopK Clustering

    User->>PyAPI: call mqa_topk_indexer_non_fused(q, k_cache, weights, seq_lens,...)
    PyAPI->>Runtime: ensure module & allocate logits/indices
    PyAPI->>Meta: launch get_mqa_metadata(seq_lens) -> sm_map
    PyAPI->>Logits: launch_mqa_logits(q, k_cache, weights, ..., sm_map)
    Logits->>Device: compute logits only
    Logits-->>PyAPI: logits buffer populated
    PyAPI->>TopK: launch_fast_topk_clusters(logits, indices, seq_lens)
    TopK->>Device: compute top-K over logits
    TopK-->>PyAPI: indices returned
    PyAPI-->>User: (logits, indices)

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

ready, op: moe-routing

Suggested reviewers

  • jiahanc
  • bkryu
  • cyx-6
  • yzh119
  • jimmyzho
  • nvmbreughe
  • kahyunnam


🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check (⚠️ Warning): the PR description does not follow the provided template structure, which requires sections such as Related Issues, Pre-commit Checks, and Tests checklist items. Resolution: update the description to include all template sections: Related Issues, Pre-commit Checks (with checkboxes), Tests (with checkboxes), and Reviewer Notes.
  • Docstring Coverage (⚠️ Warning): docstring coverage is 38.16%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check (✅ Passed): the title clearly and concisely describes the main features: optimized clusters TopK kernel and DeepSeekV3.2 sparse indexer fused kernel.



@gemini-code-assist
Contributor

Summary of Changes


This pull request introduces significant performance enhancements for the low-latency decode stage, particularly for DeepSeek v3 sparse attention. It achieves this by implementing a highly optimized TopK kernel that utilizes CTA clusters and a fused radix approach, alongside a novel sparse indexer kernel that integrates histogram computation directly into the GEMM epilogue. These changes aim to reduce computational overhead and memory access, leading to substantial speedups in critical inference operations.

Highlights

  • Optimized TopK Kernel: Introduced an optimized fast_topk_clusters kernel that leverages CTA clusters on top of a fused radix algorithm, specifically targeting DeepSeek v3 sparse-attention shapes. This provides significant speedups for low-latency decode stages, especially at lower batch sizes.
  • Fused Sparse Indexer Kernel: Implemented a new fused kernel for the DeepSeekV3.2 sparse indexer. This optimization integrates the computation of the first bin's histogram directly into the epilogue of the GEMM kernel, eliminating a separate pass over logits and reducing memory reads.
  • Performance Improvements: Benchmarking shows the new fast_topk_clusters kernel achieves up to 4.32x speedup over FlashInfer's generic top_k for certain configurations. The fused sparse indexer also demonstrates speedups, up to 2.12x, compared to a non-fused approach and deep_gemm reference.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optimized clusters TopK kernel and a fused kernel for the DeepSeekV3.2 sparse indexer, targeting low-latency decode stages. The changes include a new CUDA kernel (csrc/fast_topk_clusters.cu) and modifications to benchmark scripts to evaluate the performance of the new kernels. The review focuses on correctness and potential issues in the CUDA code and benchmark scripts.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 14

🧹 Nitpick comments (1)
tests/attention/test_mqa_histogram.py (1)

130-148: Please exercise pdl_enabled=True at least once.

The fused kernel has a separate PDL completion path, but this matrix fixes pdl_enabled to [False]. A single SM100a smoke case for True would keep that advertised path from regressing unnoticed.

Possible change
-@pytest.mark.parametrize("pdl_enabled", [False])
+@pytest.mark.parametrize("pdl_enabled", [False, True])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_mqa_histogram.py` around lines 130 - 148, The test fixes
pdl_enabled to [False], skipping the fused PDL completion path; update the
pytest parametrization in test_mqa_topk_indexer so that pdl_enabled includes
True at least once (e.g., [False, True] or add a separate test case) to exercise
the PDL path; locate the test function test_mqa_topk_indexer and the call to
mqa_topk_indexer (pdl_enabled=...) and change the
`@pytest.mark.parametrize`("pdl_enabled", [False]) line to include True so the
fused kernel's PDL completion is validated.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_dsv3_sparse_indexer.py`:
- Around line 10-14: The usage advertises --compare-deepgemm but argparse never
defines it and compare_deepgemm is forced to HAS_DEEP_GEMM; add a CLI flag
(e.g., parser.add_argument("--compare-deepgemm", action="store_true") or
store_false as appropriate) and use args.compare_deepgemm when initializing the
compare_deepgemm variable instead of always using HAS_DEEP_GEMM, and update the
other occurrence(s) where compare_deepgemm is referenced (the main argument
parsing block and the loop that sets/uses compare_deepgemm) so the flag actually
enables/disables the DeepGEMM comparison.

In `@csrc/fast_topk_clusters.cu`:
- Around line 428-438: The extern_shared_mem calculation (extern_shared_mem)
grows with num_cached but the setup_kernel_smem_once calls for
fast_topk_clusters_kernel<N> are hard-coded to a fixed 4096*16+2048*4
shared-memory size, which can mismatch and cause kernel launch failures when
num_cached is large; fix by computing the required dynamic shared memory based
on num_cached and TopK, clamp it to the kernel's opted-in maximum, and call
setup_kernel_smem_once<fast_topk_clusters_kernel<...>>(required_smem) for each
instantiated kernel (fast_topk_clusters_kernel<1>, <2>, <4>, <8>, <16>), or add
a guard that prevents registering more than the kernel's max shared memory
budget when num_cached exceeds the supported value.
- Around line 330-336: The loop compares a full 32-bit payload from
s_cached_logit_bits to an 8-bit threshold_bin, causing missed matches in the
final refill pass; change the extraction to mask the low byte (e.g., int bin =
s_cached_logit_bits[i] & 0xFF) before comparing to threshold_bin so only the
8-bit bucket is compared; update the loop that uses s_cached_logit_bits,
cached_idx, threshold_bin and s_topk_inds accordingly to ensure remaining
candidates are correctly detected and added.
- Around line 440-471: The switch over num_clusters only covers {1,2,4,8,16} and
silently falls through for other values leaving indices unmodified; add a
default branch in the switch to reject unsupported num_clusters explicitly
(e.g., log an error or set an error return code and return/throw), referencing
the fast_topk_clusters_kernel launch sites and the indices/output buffer so
callers won't get stale results; ensure the default path clearly documents the
unsupported value (num_clusters) and exits the function consistently with the
surrounding error-handling convention.
- Around line 224-230: The code drops elements when cached_offset >= num_cached
which makes the TopK path approximate; instead of silently discarding, always
account for the item's contribution to the next-pass histogram: keep the
existing branch that writes s_cached_indices and s_cached_logit_bits when
cached_offset < num_cached, but in the else path still perform the histogram
update (atomicAdd(shared_hist[1] + ((bits >> 16) & 0xff), 1)) so discarded items
are counted; apply this same change to the other symmetric block around
s_cached_*/shared_num_cached_count (the second occurrence) so all overflow
candidates still update shared_hist even when not stored in s_cached_* arrays.

In `@csrc/mqa_v2_hist.cu`:
- Around line 509-535: The kernel reads only sm_mapping[sm_id] instead of the
per-SM tile entry written by launch_mqa_kernel_metadata, so it ignores all tiles
beyond the first num_sms; replace the read of sm_mapping in the mqa_v2_hist
kernel (the assignment to sm_map currently using sm_mapping[sm_id]) with the
correct index that consumes every tile: use sm_mapping[sm_id * sm_multiple +
sm_loc] (and keep the following __shfl_sync broadcasts), and apply the same fix
to the corresponding block around lines 579-633 where the same pattern appears;
this ensures the kernel consumes all sm_mapping entries created by
launch_mqa_kernel_metadata.
- Around line 542-560: The function init_tensormap_nd currently ends with
assert(err == CUDA_SUCCESS) which is compiled out in release builds; replace
this assert with runtime error handling that checks the return value from
cuTensorMapEncodeTiled and reports/throws on failure (e.g., call the existing
checkCu() helper used elsewhere or throw a std::runtime_error with a message
containing the cuTensorMapEncodeTiled error code and context), so the tensor-map
descriptor is not left invalid at runtime; update init_tensormap_nd to perform
this check after the cuTensorMapEncodeTiled call and propagate/throw an
informative error if err != CUDA_SUCCESS.

In `@csrc/topk_clusters_pre_hist.cu`:
- Around line 297-303: Refactor the comparison inside the refill loop to use
only the low 8 bits of s_cached_logit_bits when matching threshold_bin: in the
for loop that iterates i over buf_len (using s_cached_logit_bits,
s_cached_indices, threshold_bin, s_topk_inds, TopK, shared_final_idx_count),
mask the full 32-bit bin (currently assigned from s_cached_logit_bits[i]) with
0xFF (e.g., compute an 8-bit bin8) and compare bin8 == threshold_bin before
doing the atomicAdd and storing cached_idx so the final refill pass correctly
matches the last 8-bit radix bucket.
- Around line 217-223: The current TopK path silently drops elements when
cached_offset >= num_cached (in the block that writes
s_cached_indices/s_cached_logit_bits and atomically increments
shared_num_cached_count), which makes the path approximate; change it to detect
and record overflow: after atomicAdd(&shared_num_cached_count[0], 1) check if
cached_offset < num_cached to keep the existing cache write, otherwise set a
shared overflow flag (e.g., shared_overflow[0] = 1 using atomicOr) and ensure
the histogram still accounts for the overflowed element (increment shared_hist
directly or mark the bin as overflowed); add the same overflow handling to the
other identical block (the one around s_cached_* at the later location) and wire
a fallback path that, when shared_overflow is set, forces the algorithm to
process that bin in full (non-cached) mode in the next pass so no elements are
silently dropped.

In `@flashinfer/mqa_histogram.py`:
- Around line 233-235: The allocation for logits uses max_model_len without
validating it against actual sequence lengths, risking out-of-bounds writes when
seq_lens[b] > max_model_len; update the allocation logic in mqa_histogram.py
(where logits and indices are created) to compute the required_row_len =
max(max_model_len, seq_lens.max()) or clamp seq_lens to max_model_len and
allocate logits with required_row_len (and similarly adjust indices allocation),
and apply the same fix at the other occurrence (the block around the second
allocation at lines ~275-278) so kernel writes cannot exceed the allocated row
length.
- Around line 192-205: The public API get_mqa_metadata should be gated with the
`@backend_requirement` decorator and the module should expose/support helper
predicates is_compute_capability_supported(cc) and is_backend_supported() so
callers fail fast on unsupported devices; update get_mqa_metadata (and the other
public wrappers in the same block: the functions spanning the region around
get_mqa_metadata and the subsequent APIs between the referenced lines) to have
`@backend_requirement` applied, implement or import
is_compute_capability_supported(cc) and is_backend_supported(), and ensure the
decorator uses those helpers to check device compute capability before entering
get_mqa_histogram_module().get_mqa_metadata to prevent JIT/build/launch on
unsupported hardware.

In `@include/flashinfer/mqa_histogram/common.cuh`:
- Around line 107-130: The WLMS helper currently has an unconditional return
that prevents the warp-aggregated path from ever executing; remove that return
and restructure the function so that when active_mask is true it attempts the
warp-aggregation branch (using __ballot_sync, getLaneMaskLt(), __popc and
atomicAdd on shared_hist + val) and falls back to per-lane atomicAdd only when
the warp-aggregation condition isn't met or active_mask is false; update logic
in WLMS(uint8_t val, int *shared_hist, bool active_mask = true) to first compute
warpFlags = __ballot_sync(0xffffffff, active_mask), derive per-lane uniqueness
with the existing bit tests and getLaneMaskLt() check, and call
atomicAdd(shared_hist + val, __popc(warpFlags)) only from the first lane (bits
== 0) while preserving the previous per-lane atomicAdd behavior as a fallback.
- Around line 8-16: The CUDA_CHECK macro should not call exit(); change it to
propagate failures so callers like setup_kernel_smem_once() can handle errors:
replace the current CUDA_CHECK (used around cudaFuncSetAttribute and elsewhere)
with a variant that captures cudaError_t and returns or throws it (e.g., return
err from the enclosing function or rethrow as a cudaError_t/exception), then
update setup_kernel_smem_once() to use that non-fatal check and propagate the
error to the Python-facing caller instead of terminating the process; apply the
same change for the other occurrences mentioned (lines ~145-153) so all
cudaFuncSetAttribute and similar calls surface errors to callers.

In `@include/flashinfer/mqa_histogram/tcgen05_utils.cuh`:
- Around line 162-177: The function template tcgen05_ld_32x32b has a
non-dependent static_assert inside a discarded else branch which is ill-formed
for nvcc/C++ even if only valid WIDTHs are instantiated; move the static_assert
out of the else branch to function scope and make it dependent on the template
parameter (for example assert (WIDTH==2 || WIDTH==4 || WIDTH==8 || WIDTH==16 ||
WIDTH==32 || WIDTH==64)) so the assertion is only evaluated per-instantiation
and not unconditionally; keep the existing if constexpr chain (calling
tcgen05_ld_32x32b_x2/_x4/_x8/_x16/_x32/_x64) and replace the final
else/static_assert pattern with a single dependent static_assert after the
chain.

---

Nitpick comments:
In `@tests/attention/test_mqa_histogram.py`:
- Around line 130-148: The test fixes pdl_enabled to [False], skipping the fused
PDL completion path; update the pytest parametrization in test_mqa_topk_indexer
so that pdl_enabled includes True at least once (e.g., [False, True] or add a
separate test case) to exercise the PDL path; locate the test function
test_mqa_topk_indexer and the call to mqa_topk_indexer (pdl_enabled=...) and
change the `@pytest.mark.parametrize`("pdl_enabled", [False]) line to include True
so the fused kernel's PDL completion is validated.


📥 Commits

Reviewing files that changed from the base of the PR and between a83ec99 and 4dbfaea.

📒 Files selected for processing (15)
  • benchmarks/bench_dsv3_sparse_indexer.py
  • benchmarks/bench_topk.py
  • csrc/fast_topk_clusters.cu
  • csrc/flashinfer_mqa_histogram_binding.cu
  • csrc/mqa_metadata.cu
  • csrc/mqa_v2.cu
  • csrc/mqa_v2_hist.cu
  • csrc/topk_clusters_pre_hist.cu
  • flashinfer/__init__.py
  • flashinfer/aot.py
  • flashinfer/jit/mqa_histogram.py
  • flashinfer/mqa_histogram.py
  • include/flashinfer/mqa_histogram/common.cuh
  • include/flashinfer/mqa_histogram/tcgen05_utils.cuh
  • tests/attention/test_mqa_histogram.py

@kahyunnam
Collaborator

/bot run

@kahyunnam
Collaborator

@Aalanli what compute capability did you optimize for / on what gpu did you collect these benchmarking results?

@flashinfer-bot
Collaborator

GitLab MR !428 has been created, and the CI pipeline #46461374 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (7)
csrc/mqa_v2_hist.cu (2)

145-146: ⚠️ Potential issue | 🟡 Minor

Fix typo in register limit constant.

Same issue as in csrc/mqa_v2.cu: uses 65636 instead of the correct 65536.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/mqa_v2_hist.cu` around lines 145 - 146, The static_assert that checks
register usage uses the incorrect constant 65636; update the assertion in the
static_assert statement that references num_epilogue_threads, math_nreg,
num_special_warps, and special_nreg to use the correct hardware limit 65536
instead of 65636 so the check reads "... <= 65536" and preserves the existing
error message "too many registers".

503-509: ⚠️ Potential issue | 🟠 Major

Replace assert with runtime error handling for cuTensorMapEncodeTiled.

Same issue as in csrc/mqa_v2.cu: the assertion compiles out with NDEBUG, leaving invalid tensor-map descriptors undetected.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/mqa_v2_hist.cu` around lines 503 - 509, The call to
cuTensorMapEncodeTiled stores its return in err but uses assert(err ==
CUDA_SUCCESS), which can be compiled out; replace the assert with explicit
runtime error handling: check if err != CUDA_SUCCESS after
cuTensorMapEncodeTiled (the same pattern used in csrc/mqa_v2.cu), log or format
a clear error message including the numeric err and a short context string
(e.g., "cuTensorMapEncodeTiled failed"), and then abort/throw/return an error
code as appropriate for this module to prevent proceeding with an invalid
tensor-map descriptor; update the code around the err variable and the
cuTensorMapEncodeTiled call accordingly.
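A minimal sketch of the always-on check the review asks for, with a stub error type standing in for `CUresult` (in the real code the value returned by cuTensorMapEncodeTiled would be passed in):

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Stand-in for CUresult (the CUDA driver API's error enum); 0 means success.
using DriverErr = int;
constexpr DriverErr kDriverSuccess = 0;

// Unlike assert(), this check is not compiled out under NDEBUG.
inline void check_driver(DriverErr err, const std::string& what) {
  if (err != kDriverSuccess) {
    throw std::runtime_error(what + " failed with driver error " + std::to_string(err));
  }
}
```

In csrc/mqa_v2_hist.cu this would wrap the call as `check_driver(cuTensorMapEncodeTiled(...), "cuTensorMapEncodeTiled")` so an invalid tensor-map descriptor is never consumed.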
include/flashinfer/mqa_histogram/tcgen05_utils.cuh (1)

152-169: ⚠️ Potential issue | 🟠 Major

Move static_assert out of the discarded else branch.

static_assert(false, ...) inside the else branch of if constexpr is non-dependent on the template parameter WIDTH, making this ill-formed per C++ standard ([temp.res]/8). nvcc will fail to compile even when only valid WIDTHs are instantiated. Place a dependent assertion at function scope instead.

🛠️ Proposed fix
 template <int WIDTH>
 __device__ inline void tcgen05_ld_32x32b(int addr, float (&tmp)[WIDTH]) {
+  static_assert(WIDTH == 2 || WIDTH == 4 || WIDTH == 8 ||
+                WIDTH == 16 || WIDTH == 32 || WIDTH == 64,
+                "WIDTH must be 2, 4, 8, 16, 32, or 64");
   if constexpr (WIDTH == 2) {
     tcgen05_ld_32x32b_x2(addr, tmp);
   } else if constexpr (WIDTH == 4) {
     tcgen05_ld_32x32b_x4(addr, tmp);
   } else if constexpr (WIDTH == 8) {
     tcgen05_ld_32x32b_x8(addr, tmp);
   } else if constexpr (WIDTH == 16) {
     tcgen05_ld_32x32b_x16(addr, tmp);
   } else if constexpr (WIDTH == 32) {
     tcgen05_ld_32x32b_x32(addr, tmp);
   } else if constexpr (WIDTH == 64) {
     tcgen05_ld_32x32b_x64(addr, tmp);
-  } else {
-    static_assert(false, "WIDTH must be 2, 4, 8, 16, 32, or 64");
   }
 }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mqa_histogram/tcgen05_utils.cuh` around lines 152 - 169,
The static_assert inside the discarded else branch of the template function
tcgen05_ld_32x32b is non-dependent on WIDTH and can cause compilation failures;
move the check out of the discarded branch and make it dependent on the template
parameter (e.g., use a helper like always_false<WIDTH> or similar) and place a
single static_assert(always_false<WIDTH>, "WIDTH must be 2, 4, 8, 16, 32, or
64") at function scope after the if constexpr chain so only invalid WIDTH
instantiations trigger the assertion; ensure the helper template (always_false)
depends on WIDTH so the assertion is SFINAE-friendly.
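One standard shape for the dependent-assert helper the prompt describes — a host-compilable sketch with only two widths, not the real tcgen05_ld_32x32b dispatch:

```cpp
#include <cassert>

// Dependent false: evaluated only when an invalid WIDTH is actually instantiated.
template <int>
inline constexpr bool always_false_v = false;

template <int WIDTH>
int dispatch_width() {
  if constexpr (WIDTH == 2) {
    return 2;
  } else if constexpr (WIDTH == 4) {
    return 4;
  } else {
    // Well-formed under [temp.res]/8 because the condition depends on WIDTH.
    static_assert(always_false_v<WIDTH>, "WIDTH must be 2 or 4 in this sketch");
    return 0;  // unreachable; keeps all paths returning
  }
}
```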
flashinfer/mqa_histogram.py (2)

223-240: ⚠️ Potential issue | 🟠 Major

Gate public APIs behind @backend_requirement for SM100a.

These APIs require SM100a (Blackwell) but don't use the @backend_requirement decorator. Callers will encounter cryptic JIT or launch failures on unsupported GPUs. Add the decorator to fail fast with a clear error message.

As per coding guidelines: flashinfer/*.py: Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mqa_histogram.py` around lines 223 - 240, Add a backend guard to
the public API by decorating get_mqa_metadata with `@backend_requirement`, import
the decorator and supply the module-level support query functions
(is_compute_capability_supported and is_backend_supported) so callers fail fast
on non-SM100a GPUs; specifically, add the
`@backend_requirement`(is_compute_capability_supported, is_backend_supported)
decorator above the get_mqa_metadata function and ensure those two functions are
defined/accessibly imported in the same module.

268-272: ⚠️ Potential issue | 🔴 Critical

Validate that max_model_len accommodates the largest sequence.

The logits buffer is allocated with shape [batch_size, max_model_len], but kernels write up to seq_lens[b] elements per row. If any sequence exceeds max_model_len, this causes out-of-bounds writes.

🛡️ Proposed validation
+    max_seq = int(seq_lens.max().item())
+    if max_seq > max_model_len:
+        raise ValueError(
+            f"max_model_len={max_model_len} is smaller than the longest sequence {max_seq}"
+        )
     logits = torch.empty(
         batch_size, max_model_len, device=q.device, dtype=torch.float32
     )

Also applies to mqa_topk_indexer at lines 313-316.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mqa_histogram.py` around lines 268 - 272, The logits buffer
(logits) and related index buffer (indices) are allocated using max_model_len
but kernels write up to seq_lens[b], so validate that max_model_len >=
max(seq_lens) before allocation (or reallocate using max_seq_len derived from
seq_lens); update both the block that creates logits/indices in mqa_histogram
(the variables logits, indices, q, max_model_len) and the analogous allocation
in mqa_topk_indexer to compute max_seq_len = seq_lens.max().item() (or raise a
clear error) and use that value for the second dimension to prevent
out‑of‑bounds writes. Ensure the validation uses the existing seq_lens
tensor/variable so no silent truncation occurs.
csrc/fast_topk_clusters.cu (1)

397-407: ⚠️ Potential issue | 🟠 Major

Guard against num_cached exceeding the registered shared-memory budget.

extern_shared_mem scales with num_cached, but setup_kernel_smem_once is hardcoded to the 4096-candidate footprint (4096 * 16 + 2048 * 4). Larger num_cached values will request more dynamic shared memory than configured, causing launch failures.

🛡️ Proposed guard
   int extern_shared_mem =
       (num_cached * 2 * sizeof(float) + num_cached * 2 * sizeof(int) +
        TopK * sizeof(int));  // 2 * num_cached float, 2 * num_cached int, topk int
+
+  constexpr int max_smem_budget = 4096 * 16 + 2048 * 4;
+  assert(extern_shared_mem <= max_smem_budget &&
+         "num_cached exceeds configured shared-memory budget (max 4096)");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fast_topk_clusters.cu` around lines 397 - 407, The code computes
extern_shared_mem from num_cached but always registers a fixed shared-memory
footprint using setup_kernel_smem_once for fast_topk_clusters_kernel<N>; guard
against num_cached exceeding that registered budget by computing the required
dynamic shared memory (extern_shared_mem) and ensuring you only call
setup_kernel_smem_once/setup_non_portable_clusters_once with a footprint >=
extern_shared_mem (or clamp num_cached to the max supported value before
computing extern_shared_mem), and if num_cached is larger emit a clear
error/return; update references around extern_shared_mem, num_cached,
setup_kernel_smem_once, setup_non_portable_clusters_once and
fast_topk_clusters_kernel so the registered smem matches the runtime requirement
or the code refuses unsupported sizes.
benchmarks/bench_dsv3_sparse_indexer.py (1)

10-14: ⚠️ Potential issue | 🟡 Minor

Wire the documented --compare-deepgemm flag.

The usage text advertises --compare-deepgemm but argparse never defines it. The benchmark always uses compare_deepgemm=HAS_DEEP_GEMM regardless of user intent.

🔧 Proposed fix
     parser.add_argument(
+        "--compare-deepgemm",
+        action="store_true",
+        help="Benchmark deep_gemm fp8_paged_mqa_logits + flashinfer top_k",
+    )
+    parser.add_argument(
         "--pdl",
         action="store_true",

And update line 273:

-                    compare_deepgemm=HAS_DEEP_GEMM,
+                    compare_deepgemm=args.compare_deepgemm and HAS_DEEP_GEMM,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_dsv3_sparse_indexer.py` around lines 10 - 14, Argparse isn’t
exposing the documented --compare-deepgemm flag, so the script always uses
compare_deepgemm = HAS_DEEP_GEMM; add a boolean flag to the ArgumentParser
(e.g., parser.add_argument("--compare-deepgemm", action="store_true",
help="Compare against DeepGEMM")) and then use the parsed value
(args.compare_deepgemm) instead of defaulting to HAS_DEEP_GEMM when setting
compare_deepgemm (and update the assignment near where compare_deepgemm is set,
referenced as compare_deepgemm and HAS_DEEP_GEMM).
🧹 Nitpick comments (2)
tests/attention/test_mqa_histogram.py (2)

94-96: Prefix unused variable with underscore.

num_heads is unpacked but never used. Prefix it with _ to indicate intentional discard.

-    batch_size, num_heads, head_dim = q_fp8.shape
+    batch_size, _num_heads, head_dim = q_fp8.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_mqa_histogram.py` around lines 94 - 96, In
_dsa_topk_indexer change the unused unpacked variable `num_heads` to a discarded
name (e.g., `_num_heads` or `_`) so the tuple unpacking reads `batch_size,
_num_heads, head_dim = q_fp8.shape`, indicating the value is intentionally
unused and avoiding linter warnings; update only the variable name in that
assignment inside the _dsa_topk_indexer function.

159-165: Rename ambiguous variable l.

The variable name l is visually ambiguous (looks like digit 1). Rename to seq_len_i or similar for clarity.

-        l = int(seq_lens[i])
+        seq_len_i = int(seq_lens[i])

         logit_ref = logits_ref[i]
-        logit_0 = logits_0[i][:l]
-        logit_1 = logits_1[i][:l]
+        logit_0 = logits_0[i][:seq_len_i]
+        logit_1 = logits_1[i][:seq_len_i]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_mqa_histogram.py` around lines 159 - 165, Rename the
ambiguous loop variable `l` to a clearer name like `seq_len_i` in the loop that
iterates over range(B) where `prefix`, `seq_lens`, `logits_ref`, `logits_0`, and
`logits_1` are used; update its uses (currently `l = int(seq_lens[i])` and
slicing `logit_0 = logits_0[i][:l]`, `logit_1 = logits_1[i][:l]`) to `seq_len_i`
so the intent is clear and avoids confusion with the digit "1".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/fast_topk_clusters.cu`:
- Around line 307-322: In the final refill pass (the t == 3 branch),
s_cached_logit_bits holds the full 32-bit payload but threshold_bin is an 8-bit
bucket value, so change the comparison to mask the low byte before comparing:
when reading int bin = s_cached_logit_bits[i] inside the loop, mask it (bin &
0xFF) or otherwise extract the lowest byte and compare that to threshold_bin;
keep the existing atomicAdd into shared_final_idx_count and the TopK-bound write
into s_topk_inds unchanged so the refill will correctly match and collect
candidates.
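The masking this comment describes can be sketched on the host; the payload layout here is an assumption for illustration (bucket byte packed into the low 8 bits of the cached word):

```cpp
#include <cassert>
#include <cstdint>

// Assumed payload layout (illustrative only): low 8 bits hold the radix bucket,
// the upper 24 bits hold the candidate index.
inline int bucket_of(uint32_t cached_bits) { return static_cast<int>(cached_bits & 0xFF); }

inline bool matches_threshold(uint32_t cached_bits, int threshold_bin) {
  // Comparing the full 32-bit payload against the 8-bit bin would almost never
  // match; extract the low byte first.
  return bucket_of(cached_bits) == threshold_bin;
}
```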

In `@csrc/flashinfer_mqa_histogram_binding.cu`:
- Around line 255-270: fast_topk_clusters currently casts num_clusters to int
and calls launch_fast_topk_clusters without validating it; add a validation step
in fast_topk_clusters to ensure num_clusters is one of the supported values
{1,2,4,8,16} before calling launch_fast_topk_clusters (use the num_clusters
variable and the launch_fast_topk_clusters call site to locate the change) and
if it is not supported, return/throw a clear error (or set indices to a safe
default) to avoid silently skipping the kernel and leaving indices
uninitialized; keep all other parameter checks (check_logits, check_indices,
check_seq_lens) intact.
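A possible shape for that validation, assuming the supported set stays {1, 2, 4, 8, 16}:

```cpp
#include <cassert>

// True iff n is one of {1, 2, 4, 8, 16}: a power of two no greater than 16.
inline bool is_supported_num_clusters(int n) {
  return n > 0 && n <= 16 && (n & (n - 1)) == 0;
}
```

The binding would call this before launch_fast_topk_clusters and raise a clear error (e.g. via TORCH_CHECK) instead of silently skipping the kernel.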

In `@csrc/mqa_v2.cu`:
- Around line 141-142: The static_assert that checks register usage (involving
num_epilogue_threads, math_nreg, num_special_warps, and special_nreg) uses the
wrong constant 65636; change that literal to the correct GPU register limit
65536 so the assertion reads that combined registers <= 65536, keeping the same
message "too many registers".
- Around line 479-486: The assert(err == CUDA_SUCCESS) after
cuTensorMapEncodeTiled is unsafe because it can compile out; replace it with
proper runtime error handling: check the returned err from
cuTensorMapEncodeTiled and call the existing helper (e.g., checkCu(err,
"cuTensorMapEncodeTiled") used in csrc/xqa/tensorMap.cpp) or throw a
std::runtime_error with a clear message including the cuError name/number and
context (e.g., which tmap/ptr operation failed); ensure any needed cleanup or
early return follows the error path so an invalid tensor-map descriptor is never
used.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2f34ad89-71f6-4b38-9132-a193346b8b1b

📥 Commits

Reviewing files that changed from the base of the PR and between 4dbfaea and 742bfb9.

📒 Files selected for processing (12)
  • benchmarks/bench_dsv3_sparse_indexer.py
  • benchmarks/bench_topk.py
  • csrc/fast_topk_clusters.cu
  • csrc/flashinfer_mqa_histogram_binding.cu
  • csrc/mqa_metadata.cu
  • csrc/mqa_v2.cu
  • csrc/mqa_v2_hist.cu
  • csrc/topk_clusters_pre_hist.cu
  • flashinfer/mqa_histogram.py
  • include/flashinfer/mqa_histogram/common.cuh
  • include/flashinfer/mqa_histogram/tcgen05_utils.cuh
  • tests/attention/test_mqa_histogram.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • csrc/mqa_metadata.cu
  • csrc/topk_clusters_pre_hist.cu
  • benchmarks/bench_topk.py
  • include/flashinfer/mqa_histogram/common.cuh

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46461374: 11/20 passed

@Aalanli
Author

Aalanli commented Mar 19, 2026

Hi @kahyunnam, I benchmarked on a bare-metal B200 node, targeting sm100a. Note that the topK works on both Hopper and Blackwell, but the gemm kernel uses tcgen05 instructions and therefore only works on Blackwell. Most of the speedup I observed comes from the topK; the histogram fusion and PDL add a slight further boost.

This PR is mostly to get the comparisons going, the code still needs a lot of cleanup if we want to merge this.

@Aalanli Aalanli marked this pull request as draft March 19, 2026 18:22