[feat] Add optimized clusters TopK kernel and DeepSeekV3.2 sparse indexer fused kernel #2814
Aalanli wants to merge 4 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Adds a new MQA histogram-based DeepSeek v3 sparse-attention indexer: CUDA kernels (logits, fused epilogue, histogram-based top-K), TVM-FFI bindings, Python JIT wrappers and API, benchmarks, and tests; supports PDL-enabled launches and optional deep_gemm reference timing.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as Python Caller
    participant PyAPI as flashinfer.mqa_histogram
    participant Runtime as JIT Module
    participant Device as CUDA Device
    participant Meta as Metadata Kernel
    participant Epilogue as MQA Fused Epilogue
    participant TopK as FastTopK Clustering
    User->>PyAPI: call mqa_topk_indexer(q, k_cache, weights, seq_lens,...)
    PyAPI->>Runtime: ensure module built & buffers allocated (logits, histogram, indices)
    PyAPI->>Meta: launch get_mqa_metadata(seq_lens) -> sm_map
    PyAPI->>Epilogue: launch_mqa_v3_fused_epilogue(q, k_cache, weights, ..., sm_map)
    Epilogue->>Device: tcgen05 MMA loads, compute logits + update histogram
    Epilogue-->>PyAPI: logits + histogram ready on device
    PyAPI->>TopK: fast_topk_clusters_fused(logits, histogram, seq_lens)
    TopK->>Device: histogram-based pruning, caching, emit top-K indices
    TopK-->>PyAPI: return indices (and optionally logits)
    PyAPI-->>User: (logits, indices)
```
```mermaid
sequenceDiagram
    participant User as Python Caller
    participant PyAPI as flashinfer.mqa_histogram
    participant Runtime as JIT Module
    participant Device as CUDA Device
    participant Meta as Metadata Kernel
    participant Logits as MQA Logits Kernel
    participant TopK as FastTopK Clustering
    User->>PyAPI: call mqa_topk_indexer_non_fused(q, k_cache, weights, seq_lens,...)
    PyAPI->>Runtime: ensure module & allocate logits/indices
    PyAPI->>Meta: launch get_mqa_metadata(seq_lens) -> sm_map
    PyAPI->>Logits: launch_mqa_logits(q, k_cache, weights, ..., sm_map)
    Logits->>Device: compute logits only
    Logits-->>PyAPI: logits buffer populated
    PyAPI->>TopK: launch_fast_topk_clusters(logits, indices, seq_lens)
    TopK->>Device: compute top-K over logits
    TopK-->>PyAPI: indices returned
    PyAPI-->>User: (logits, indices)
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces significant performance enhancements for the low-latency decode stage, particularly for DeepSeek v3 sparse attention. It achieves this by implementing a highly optimized TopK kernel that utilizes CTA clusters and a fused radix approach, alongside a novel sparse indexer kernel that integrates histogram computation directly into the GEMM epilogue. These changes aim to reduce computational overhead and memory access, leading to substantial speedups in critical inference operations.
Code Review
This pull request introduces an optimized clusters TopK kernel and a fused kernel for the DeepSeekV3.2 sparse indexer, targeting low-latency decode stages. The changes include a new CUDA kernel (csrc/fast_topk_clusters.cu) and modifications to benchmark scripts to evaluate the performance of the new kernels. The review focuses on correctness and potential issues in the CUDA code and benchmark scripts.
Actionable comments posted: 14
🧹 Nitpick comments (1)
tests/attention/test_mqa_histogram.py (1)
130-148: Please exercise `pdl_enabled=True` at least once.

The fused kernel has a separate PDL completion path, but this matrix fixes `pdl_enabled` to `[False]`. A single SM100a smoke case for `True` would keep that advertised path from regressing unnoticed.

Possible change:

```diff
-@pytest.mark.parametrize("pdl_enabled", [False])
+@pytest.mark.parametrize("pdl_enabled", [False, True])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_mqa_histogram.py` around lines 130 - 148, The test fixes pdl_enabled to [False], skipping the fused PDL completion path; update the pytest parametrization in test_mqa_topk_indexer so that pdl_enabled includes True at least once (e.g., [False, True] or add a separate test case) to exercise the PDL path; locate the test function test_mqa_topk_indexer and the call to mqa_topk_indexer (pdl_enabled=...) and change the `@pytest.mark.parametrize`("pdl_enabled", [False]) line to include True so the fused kernel's PDL completion is validated.
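The widening suggested above can be sanity-checked with a plain cartesian product; the second axis (`batch_sizes`) is a made-up stand-in for the test's real parameters:

```python
# Sketch: adding True to the pdl_enabled axis doubles the matrix, so the
# PDL completion path is exercised at least once.
import itertools

def expand_matrix(pdl_values, batch_sizes):
    """Mimics how pytest.mark.parametrize crosses two axes."""
    return list(itertools.product(pdl_values, batch_sizes))

before = expand_matrix([False], [1, 8])        # PDL path never runs
after = expand_matrix([False, True], [1, 8])   # PDL path covered twice
```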
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_dsv3_sparse_indexer.py`:
- Around line 10-14: The usage advertises --compare-deepgemm but argparse never
defines it and compare_deepgemm is forced to HAS_DEEP_GEMM; add a CLI flag
(e.g., parser.add_argument("--compare-deepgemm", action="store_true") or
store_false as appropriate) and use args.compare_deepgemm when initializing the
compare_deepgemm variable instead of always using HAS_DEEP_GEMM, and update the
other occurrence(s) where compare_deepgemm is referenced (the main argument
parsing block and the loop that sets/uses compare_deepgemm) so the flag actually
enables/disables the DeepGEMM comparison.
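A minimal, standalone sketch of the wiring this prompt describes; the `HAS_DEEP_GEMM` probe is stubbed out here rather than taken from the benchmark:

```python
# Hypothetical reduction of the benchmark's CLI: define the advertised flag
# and let the import probe only gate, not override, the user's choice.
import argparse

HAS_DEEP_GEMM = False  # stand-in for the real "try: import deep_gemm" probe

parser = argparse.ArgumentParser()
parser.add_argument(
    "--compare-deepgemm",
    action="store_true",
    help="Also benchmark the deep_gemm reference path",
)
args = parser.parse_args(["--compare-deepgemm"])

# User asked for the comparison, but deep_gemm is unavailable -> stays off.
compare_deepgemm = args.compare_deepgemm and HAS_DEEP_GEMM
```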
In `@csrc/fast_topk_clusters.cu`:
- Around line 428-438: The extern_shared_mem calculation (extern_shared_mem)
grows with num_cached but the setup_kernel_smem_once calls for
fast_topk_clusters_kernel<N> are hard-coded to a fixed 4096*16+2048*4
shared-memory size, which can mismatch and cause kernel launch failures when
num_cached is large; fix by computing the required dynamic shared memory based
on num_cached and TopK, clamp it to the kernel's opted-in maximum, and call
setup_kernel_smem_once<fast_topk_clusters_kernel<...>>(required_smem) for each
instantiated kernel (fast_topk_clusters_kernel<1>, <2>, <4>, <8>, <16>), or add
a guard that prevents registering more than the kernel's max shared memory
budget when num_cached exceeds the supported value.
- Around line 330-336: The loop compares a full 32-bit payload from
s_cached_logit_bits to an 8-bit threshold_bin, causing missed matches in the
final refill pass; change the extraction to mask the low byte (e.g., int bin =
s_cached_logit_bits[i] & 0xFF) before comparing to threshold_bin so only the
8-bit bucket is compared; update the loop that uses s_cached_logit_bits,
cached_idx, threshold_bin and s_topk_inds accordingly to ensure remaining
candidates are correctly detected and added.
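The masking bug is easy to reproduce outside CUDA; the packing layout below (index in the high bits, 8-bit bucket in the low byte) is assumed from the fix's description, not read from the kernel:

```python
# Toy payload: a candidate's index packed above its 8-bit radix bucket.
def pack(index, bucket):
    return (index << 8) | (bucket & 0xFF)

threshold_bin = 0x3D
payload = pack(index=42, bucket=0x3D)

buggy_match = payload == threshold_bin           # full 32-bit compare: misses
fixed_match = (payload & 0xFF) == threshold_bin  # low-byte compare: hits
```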
- Around line 440-471: The switch over num_clusters only covers {1,2,4,8,16} and
silently falls through for other values leaving indices unmodified; add a
default branch in the switch to reject unsupported num_clusters explicitly
(e.g., log an error or set an error return code and return/throw), referencing
the fast_topk_clusters_kernel launch sites and the indices/output buffer so
callers won't get stale results; ensure the default path clearly documents the
unsupported value (num_clusters) and exits the function consistently with the
surrounding error-handling convention.
- Around line 224-230: The code drops elements when cached_offset >= num_cached
which makes the TopK path approximate; instead of silently discarding, always
account for the item's contribution to the next-pass histogram: keep the
existing branch that writes s_cached_indices and s_cached_logit_bits when
cached_offset < num_cached, but in the else path still perform the histogram
update (atomicAdd(shared_hist[1] + ((bits >> 16) & 0xff), 1)) so discarded items
are counted; apply this same change to the other symmetric block around
s_cached_*/shared_num_cached_count (the second occurrence) so all overflow
candidates still update shared_hist even when not stored in s_cached_* arrays.
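The invariant this comment asks for — overflowed candidates still feed the next pass's histogram — can be modeled on the CPU; the byte layout and names below are illustrative, not the kernel's:

```python
# One radix refinement pass: items whose current byte matches target_bin
# either land in a bounded cache or overflow it, but ALL of them must be
# counted in the histogram over the next byte, or later passes miscount.
def refine_pass(values, target_bin, shift, cache_capacity):
    cache, next_hist = [], [0] * 256
    for v in values:
        if (v >> shift) & 0xFF != target_bin:
            continue
        if len(cache) < cache_capacity:
            cache.append(v)
        # Histogram update happens regardless of cache overflow.
        next_hist[(v >> (shift - 8)) & 0xFF] += 1
    return cache, next_hist

vals = [(0xAB << 8) | b for b in range(10)]  # 10 items in bucket 0xAB
cache, hist = refine_pass(vals, target_bin=0xAB, shift=8, cache_capacity=4)
```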
In `@csrc/mqa_v2_hist.cu`:
- Around line 509-535: The kernel reads only sm_mapping[sm_id] instead of the
per-SM tile entry written by launch_mqa_kernel_metadata, so it ignores all tiles
beyond the first num_sms; replace the read of sm_mapping in the mqa_v2_hist
kernel (the assignment to sm_map currently using sm_mapping[sm_id]) with the
correct index that consumes every tile: use sm_mapping[sm_id * sm_multiple +
sm_loc] (and keep the following __shfl_sync broadcasts), and apply the same fix
to the corresponding block around lines 579-633 where the same pattern appears;
this ensures the kernel consumes all sm_mapping entries created by
launch_mqa_kernel_metadata.
- Around line 542-560: The function init_tensormap_nd currently ends with
assert(err == CUDA_SUCCESS) which is compiled out in release builds; replace
this assert with runtime error handling that checks the return value from
cuTensorMapEncodeTiled and reports/throws on failure (e.g., call the existing
checkCu() helper used elsewhere or throw a std::runtime_error with a message
containing the cuTensorMapEncodeTiled error code and context), so the tensor-map
descriptor is not left invalid at runtime; update init_tensormap_nd to perform
this check after the cuTensorMapEncodeTiled call and propagate/throw an
informative error if err != CUDA_SUCCESS.
In `@csrc/topk_clusters_pre_hist.cu`:
- Around line 297-303: Refactor the comparison inside the refill loop to use
only the low 8 bits of s_cached_logit_bits when matching threshold_bin: in the
for loop that iterates i over buf_len (using s_cached_logit_bits,
s_cached_indices, threshold_bin, s_topk_inds, TopK, shared_final_idx_count),
mask the full 32-bit bin (currently assigned from s_cached_logit_bits[i]) with
0xFF (e.g., compute an 8-bit bin8) and compare bin8 == threshold_bin before
doing the atomicAdd and storing cached_idx so the final refill pass correctly
matches the last 8-bit radix bucket.
- Around line 217-223: The current TopK path silently drops elements when
cached_offset >= num_cached (in the block that writes
s_cached_indices/s_cached_logit_bits and atomically increments
shared_num_cached_count), which makes the path approximate; change it to detect
and record overflow: after atomicAdd(&shared_num_cached_count[0], 1) check if
cached_offset < num_cached to keep the existing cache write, otherwise set a
shared overflow flag (e.g., shared_overflow[0] = 1 using atomicOr) and ensure
the histogram still accounts for the overflowed element (increment shared_hist
directly or mark the bin as overflowed); add the same overflow handling to the
other identical block (the one around s_cached_* at the later location) and wire
a fallback path that, when shared_overflow is set, forces the algorithm to
process that bin in full (non-cached) mode in the next pass so no elements are
silently dropped.
In `@flashinfer/mqa_histogram.py`:
- Around line 233-235: The allocation for logits uses max_model_len without
validating it against actual sequence lengths, risking out-of-bounds writes when
seq_lens[b] > max_model_len; update the allocation logic in mqa_histogram.py
(where logits and indices are created) to compute the required_row_len =
max(max_model_len, seq_lens.max()) or clamp seq_lens to max_model_len and
allocate logits with required_row_len (and similarly adjust indices allocation),
and apply the same fix at the other occurrence (the block around the second
allocation at lines ~275-278) so kernel writes cannot exceed the allocated row
length.
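The check reduces to a few lines; `seq_lens` is treated here as a plain Python sequence rather than a tensor:

```python
# Guard sketch: refuse to allocate [batch, max_model_len] buffers that a
# kernel writing seq_lens[b] entries per row could overrun.
def validate_row_len(seq_lens, max_model_len):
    max_seq = max(seq_lens)
    if max_seq > max_model_len:
        raise ValueError(
            f"max_model_len={max_model_len} is smaller than the longest "
            f"sequence ({max_seq}); kernel writes would go out of bounds"
        )
    return max_model_len
```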
- Around line 192-205: The public API get_mqa_metadata should be gated with the
`@backend_requirement` decorator and the module should expose/support helper
predicates is_compute_capability_supported(cc) and is_backend_supported() so
callers fail fast on unsupported devices; update get_mqa_metadata (and the other
public wrappers in the same block: the functions spanning the region around
get_mqa_metadata and the subsequent APIs between the referenced lines) to have
`@backend_requirement` applied, implement or import
is_compute_capability_supported(cc) and is_backend_supported(), and ensure the
decorator uses those helpers to check device compute capability before entering
get_mqa_histogram_module().get_mqa_metadata to prevent JIT/build/launch on
unsupported hardware.
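The real `@backend_requirement` decorator lives in flashinfer and its exact signature is not reproduced here; a generic stand-in that captures the fail-fast behavior might look like:

```python
import functools

def require_compute_capability(predicate, get_cc):
    """Hypothetical guard: reject calls before any JIT build or launch."""
    def deco(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            cc = get_cc()
            if not predicate(cc):
                raise RuntimeError(
                    f"{fn.__name__} requires SM100a; device reports "
                    f"compute capability {cc}"
                )
            return fn(*args, **kwargs)
        return wrapper
    return deco

def is_compute_capability_supported(cc):
    return cc == (10, 0)  # SM100a (Blackwell), per the review

@require_compute_capability(is_compute_capability_supported,
                            get_cc=lambda: (9, 0))  # pretend we're on Hopper
def get_mqa_metadata_stub():
    return "metadata"
```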
In `@include/flashinfer/mqa_histogram/common.cuh`:
- Around line 107-130: The WLMS helper currently has an unconditional return
that prevents the warp-aggregated path from ever executing; remove that return
and restructure the function so that when active_mask is true it attempts the
warp-aggregation branch (using __ballot_sync, getLaneMaskLt(), __popc and
atomicAdd on shared_hist + val) and falls back to per-lane atomicAdd only when
the warp-aggregation condition isn't met or active_mask is false; update logic
in WLMS(uint8_t val, int *shared_hist, bool active_mask = true) to first compute
warpFlags = __ballot_sync(0xffffffff, active_mask), derive per-lane uniqueness
with the existing bit tests and getLaneMaskLt() check, and call
atomicAdd(shared_hist + val, __popc(warpFlags)) only from the first lane (bits
== 0) while preserving the previous per-lane atomicAdd behavior as a fallback.
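The warp-aggregated update this comment wants to revive is the standard warp-level multi-split trick: lanes ballot for peers holding the same bin, and the lowest matching lane issues one popcount-sized atomic. A CPU model of those semantics (assumed, no CUDA involved):

```python
# 32 "lanes" update a shared histogram. Instead of one atomic per lane,
# lanes with the same bin elect the lowest active peer as leader (its
# lane-mask-lt contains no peers), and the leader adds the peer count once.
def warp_aggregated_hist(lane_bins, active, hist):
    for lane, (bin_, act) in enumerate(zip(lane_bins, active)):
        if not act:
            continue
        peers = [i for i, (b, a) in enumerate(zip(lane_bins, active))
                 if a and b == bin_]
        if lane == peers[0]:          # leader: no lower lane shares this bin
            hist[bin_] += len(peers)  # one aggregated atomicAdd
    return hist

bins = [7, 7, 3, 7, 3, 9]
mask = [True, True, True, False, True, True]  # lane 3 inactive
hist = warp_aggregated_hist(bins, mask, [0] * 16)
```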
- Around line 8-16: The CUDA_CHECK macro should not call exit(); change it to
propagate failures so callers like setup_kernel_smem_once() can handle errors:
replace the current CUDA_CHECK (used around cudaFuncSetAttribute and elsewhere)
with a variant that captures cudaError_t and returns or throws it (e.g., return
err from the enclosing function or rethrow as a cudaError_t/exception), then
update setup_kernel_smem_once() to use that non-fatal check and propagate the
error to the Python-facing caller instead of terminating the process; apply the
same change for the other occurrences mentioned (lines ~145-153) so all
cudaFuncSetAttribute and similar calls surface errors to callers.
In `@include/flashinfer/mqa_histogram/tcgen05_utils.cuh`:
- Around line 162-177: The function template tcgen05_ld_32x32b has a
non-dependent static_assert inside a discarded else branch which is ill-formed
for nvcc/C++ even if only valid WIDTHs are instantiated; move the static_assert
out of the else branch to function scope and make it dependent on the template
parameter (for example assert (WIDTH==2 || WIDTH==4 || WIDTH==8 || WIDTH==16 ||
WIDTH==32 || WIDTH==64)) so the assertion is only evaluated per-instantiation
and not unconditionally; keep the existing if constexpr chain (calling
tcgen05_ld_32x32b_x2/_x4/_x8/_x16/_x32/_x64) and replace the final
else/static_assert pattern with a single dependent static_assert after the
chain.
---
Nitpick comments:
In `@tests/attention/test_mqa_histogram.py`:
- Around line 130-148: The test fixes pdl_enabled to [False], skipping the fused
PDL completion path; update the pytest parametrization in test_mqa_topk_indexer
so that pdl_enabled includes True at least once (e.g., [False, True] or add a
separate test case) to exercise the PDL path; locate the test function
test_mqa_topk_indexer and the call to mqa_topk_indexer (pdl_enabled=...) and
change the `@pytest.mark.parametrize`("pdl_enabled", [False]) line to include True
so the fused kernel's PDL completion is validated.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f66ed369-0139-4eea-919c-24392c5c644a
📒 Files selected for processing (15)
- benchmarks/bench_dsv3_sparse_indexer.py
- benchmarks/bench_topk.py
- csrc/fast_topk_clusters.cu
- csrc/flashinfer_mqa_histogram_binding.cu
- csrc/mqa_metadata.cu
- csrc/mqa_v2.cu
- csrc/mqa_v2_hist.cu
- csrc/topk_clusters_pre_hist.cu
- flashinfer/__init__.py
- flashinfer/aot.py
- flashinfer/jit/mqa_histogram.py
- flashinfer/mqa_histogram.py
- include/flashinfer/mqa_histogram/common.cuh
- include/flashinfer/mqa_histogram/tcgen05_utils.cuh
- tests/attention/test_mqa_histogram.py
/bot run
@Aalanli what compute capability did you optimize for / on what GPU did you collect these benchmarking results?
Actionable comments posted: 4
♻️ Duplicate comments (7)
csrc/mqa_v2_hist.cu (2)
145-146: ⚠️ Potential issue | 🟡 Minor — Fix typo in register limit constant.

Same issue as in csrc/mqa_v2.cu: uses `65636` instead of the correct `65536`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/mqa_v2_hist.cu` around lines 145 - 146, The static_assert that checks register usage uses the incorrect constant 65636; update the assertion in the static_assert statement that references num_epilogue_threads, math_nreg, num_special_warps, and special_nreg to use the correct hardware limit 65536 instead of 65636 so the check reads "... <= 65536" and preserves the existing error message "too many registers".
503-509: ⚠️ Potential issue | 🟠 Major — Replace `assert` with runtime error handling for `cuTensorMapEncodeTiled`.

Same issue as in csrc/mqa_v2.cu: the assertion compiles out with `NDEBUG`, leaving invalid tensor-map descriptors undetected.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/mqa_v2_hist.cu` around lines 503 - 509, The call to cuTensorMapEncodeTiled stores its return in err but uses assert(err == CUDA_SUCCESS), which can be compiled out; replace the assert with explicit runtime error handling: check if err != CUDA_SUCCESS after cuTensorMapEncodeTiled (the same pattern used in csrc/mqa_v2.cu), log or format a clear error message including the numeric err and a short context string (e.g., "cuTensorMapEncodeTiled failed"), and then abort/throw/return an error code as appropriate for this module to prevent proceeding with an invalid tensor-map descriptor; update the code around the err variable and the cuTensorMapEncodeTiled call accordingly.

include/flashinfer/mqa_histogram/tcgen05_utils.cuh (1)
152-169: ⚠️ Potential issue | 🟠 Major — Move `static_assert` out of the discarded `else` branch.

`static_assert(false, ...)` inside the `else` branch of `if constexpr` is non-dependent on the template parameter `WIDTH`, making this ill-formed per the C++ standard ([temp.res]/8). nvcc will fail to compile even when only valid WIDTHs are instantiated. Place a dependent assertion at function scope instead.

🛠️ Proposed fix

```diff
 template <int WIDTH>
 __device__ inline void tcgen05_ld_32x32b(int addr, float (&tmp)[WIDTH]) {
+  static_assert(WIDTH == 2 || WIDTH == 4 || WIDTH == 8 ||
+                WIDTH == 16 || WIDTH == 32 || WIDTH == 64,
+                "WIDTH must be 2, 4, 8, 16, 32, or 64");
   if constexpr (WIDTH == 2) {
     tcgen05_ld_32x32b_x2(addr, tmp);
   } else if constexpr (WIDTH == 4) {
     tcgen05_ld_32x32b_x4(addr, tmp);
   } else if constexpr (WIDTH == 8) {
     tcgen05_ld_32x32b_x8(addr, tmp);
   } else if constexpr (WIDTH == 16) {
     tcgen05_ld_32x32b_x16(addr, tmp);
   } else if constexpr (WIDTH == 32) {
     tcgen05_ld_32x32b_x32(addr, tmp);
   } else if constexpr (WIDTH == 64) {
     tcgen05_ld_32x32b_x64(addr, tmp);
-  } else {
-    static_assert(false, "WIDTH must be 2, 4, 8, 16, 32, or 64");
   }
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mqa_histogram/tcgen05_utils.cuh` around lines 152 - 169, The static_assert inside the discarded else branch of the template function tcgen05_ld_32x32b is non-dependent on WIDTH and can cause compilation failures; move the check out of the discarded branch and make it dependent on the template parameter (e.g., use a helper like always_false<WIDTH> or similar) and place a single static_assert(always_false<WIDTH>, "WIDTH must be 2, 4, 8, 16, 32, or 64") at function scope after the if constexpr chain so only invalid WIDTH instantiations trigger the assertion; ensure the helper template (always_false) depends on WIDTH so the assertion is SFINAE-friendly.

flashinfer/mqa_histogram.py (2)
223-240: ⚠️ Potential issue | 🟠 Major — Gate public APIs behind `@backend_requirement` for SM100a.

These APIs require SM100a (Blackwell) but don't use the `@backend_requirement` decorator. Callers will encounter cryptic JIT or launch failures on unsupported GPUs. Add the decorator to fail fast with a clear error message.

As per coding guidelines: flashinfer/*.py: Use the `@backend_requirement` decorator on APIs that have compute capability requirements and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mqa_histogram.py` around lines 223 - 240, Add a backend guard to the public API by decorating get_mqa_metadata with `@backend_requirement`, import the decorator and supply the module-level support query functions (is_compute_capability_supported and is_backend_supported) so callers fail fast on non-SM100a GPUs; specifically, add the `@backend_requirement`(is_compute_capability_supported, is_backend_supported) decorator above the get_mqa_metadata function and ensure those two functions are defined/accessibly imported in the same module.
268-272: ⚠️ Potential issue | 🔴 Critical — Validate that `max_model_len` accommodates the largest sequence.

The `logits` buffer is allocated with shape `[batch_size, max_model_len]`, but kernels write up to `seq_lens[b]` elements per row. If any sequence exceeds `max_model_len`, this causes out-of-bounds writes.

🛡️ Proposed validation

```diff
+    max_seq = int(seq_lens.max().item())
+    if max_seq > max_model_len:
+        raise ValueError(
+            f"max_model_len={max_model_len} is smaller than the longest sequence {max_seq}"
+        )
     logits = torch.empty(
         batch_size, max_model_len, device=q.device, dtype=torch.float32
     )
```

Also applies to `mqa_topk_indexer` at lines 313-316.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mqa_histogram.py` around lines 268 - 272, The logits buffer (logits) and related index buffer (indices) are allocated using max_model_len but kernels write up to seq_lens[b], so validate that max_model_len >= max(seq_lens) before allocation (or reallocate using max_seq_len derived from seq_lens); update both the block that creates logits/indices in mqa_histogram (the variables logits, indices, q, max_model_len) and the analogous allocation in mqa_topk_indexer to compute max_seq_len = seq_lens.max().item() (or raise a clear error) and use that value for the second dimension to prevent out‑of‑bounds writes. Ensure the validation uses the existing seq_lens tensor/variable so no silent truncation occurs.

csrc/fast_topk_clusters.cu (1)
397-407: ⚠️ Potential issue | 🟠 Major — Guard against `num_cached` exceeding the registered shared-memory budget.

`extern_shared_mem` scales with `num_cached`, but `setup_kernel_smem_once` is hardcoded to the 4096-candidate footprint (`4096 * 16 + 2048 * 4`). Larger `num_cached` values will request more dynamic shared memory than configured, causing launch failures.

🛡️ Proposed guard

```diff
   int extern_shared_mem =
       (num_cached * 2 * sizeof(float) + num_cached * 2 * sizeof(int) +
        TopK * sizeof(int));  // 2 * num_cached float, 2 * num_cached int, topk int
+
+  constexpr int max_smem_budget = 4096 * 16 + 2048 * 4;
+  assert(extern_shared_mem <= max_smem_budget &&
+         "num_cached exceeds configured shared-memory budget (max 4096)");
```

🤖 Prompt for AI Agents
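Plugging the finding's sizes into plain arithmetic shows the registered footprint equals exactly the 4096-candidate requirement, so anything larger must be rejected; the helper names below are made up for illustration:

```python
# Back-of-envelope: dynamic smem needed vs. the footprint registered via
# setup_kernel_smem_once (4096 * 16 + 2048 * 4 bytes, per the finding).
SIZEOF_FLOAT = SIZEOF_INT = 4
REGISTERED_SMEM = 4096 * 16 + 2048 * 4  # 73728 bytes

def extern_shared_mem(num_cached, top_k):
    return (num_cached * 2 * SIZEOF_FLOAT
            + num_cached * 2 * SIZEOF_INT
            + top_k * SIZEOF_INT)

def check_smem(num_cached, top_k):
    need = extern_shared_mem(num_cached, top_k)
    if need > REGISTERED_SMEM:
        raise ValueError(
            f"num_cached={num_cached} needs {need} B of dynamic shared "
            f"memory but only {REGISTERED_SMEM} B was registered"
        )
    return need
```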
Verify each finding against the current code and only fix it if needed. In `@csrc/fast_topk_clusters.cu` around lines 397 - 407, The code computes extern_shared_mem from num_cached but always registers a fixed shared-memory footprint using setup_kernel_smem_once for fast_topk_clusters_kernel<N>; guard against num_cached exceeding that registered budget by computing the required dynamic shared memory (extern_shared_mem) and ensuring you only call setup_kernel_smem_once/setup_non_portable_clusters_once with a footprint >= extern_shared_mem (or clamp num_cached to the max supported value before computing extern_shared_mem), and if num_cached is larger emit a clear error/return; update references around extern_shared_mem, num_cached, setup_kernel_smem_once, setup_non_portable_clusters_once and fast_topk_clusters_kernel so the registered smem matches the runtime requirement or the code refuses unsupported sizes.

benchmarks/bench_dsv3_sparse_indexer.py (1)
10-14: ⚠️ Potential issue | 🟡 Minor — Wire the documented `--compare-deepgemm` flag.

The usage text advertises `--compare-deepgemm` but argparse never defines it. The benchmark always uses `compare_deepgemm=HAS_DEEP_GEMM` regardless of user intent.

🔧 Proposed fix

```diff
     parser.add_argument(
+        "--compare-deepgemm",
+        action="store_true",
+        help="Benchmark deep_gemm fp8_paged_mqa_logits + flashinfer top_k",
+    )
+    parser.add_argument(
         "--pdl",
         action="store_true",
```

And update line 273:

```diff
-        compare_deepgemm=HAS_DEEP_GEMM,
+        compare_deepgemm=args.compare_deepgemm and HAS_DEEP_GEMM,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_dsv3_sparse_indexer.py` around lines 10 - 14, Argparse isn’t exposing the documented --compare-deepgemm flag, so the script always uses compare_deepgemm = HAS_DEEP_GEMM; add a boolean flag to the ArgumentParser (e.g., parser.add_argument("--compare-deepgemm", action="store_true", help="Compare against DeepGEMM")) and then use the parsed value (args.compare_deepgemm) instead of defaulting to HAS_DEEP_GEMM when setting compare_deepgemm (and update the assignment near where compare_deepgemm is set, referenced as compare_deepgemm and HAS_DEEP_GEMM).
🧹 Nitpick comments (2)
tests/attention/test_mqa_histogram.py (2)
94-96: Prefix unused variable with underscore.

`num_heads` is unpacked but never used. Prefix it with `_` to indicate intentional discard.

```diff
- batch_size, num_heads, head_dim = q_fp8.shape
+ batch_size, _num_heads, head_dim = q_fp8.shape
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_mqa_histogram.py` around lines 94 - 96, In _dsa_topk_indexer change the unused unpacked variable `num_heads` to a discarded name (e.g., `_num_heads` or `_`) so the tuple unpacking reads `batch_size, _num_heads, head_dim = q_fp8.shape`, indicating the value is intentionally unused and avoiding linter warnings; update only the variable name in that assignment inside the _dsa_topk_indexer function.
159-165: Rename ambiguous variable `l`.

The variable name `l` is visually ambiguous (looks like digit `1`). Rename to `seq_len_i` or similar for clarity.

```diff
-        l = int(seq_lens[i])
+        seq_len_i = int(seq_lens[i])
         logit_ref = logits_ref[i]
-        logit_0 = logits_0[i][:l]
-        logit_1 = logits_1[i][:l]
+        logit_0 = logits_0[i][:seq_len_i]
+        logit_1 = logits_1[i][:seq_len_i]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_mqa_histogram.py` around lines 159 - 165, Rename the ambiguous loop variable `l` to a clearer name like `seq_len_i` in the loop that iterates over range(B) where `prefix`, `seq_lens`, `logits_ref`, `logits_0`, and `logits_1` are used; update its uses (currently `l = int(seq_lens[i])` and slicing `logit_0 = logits_0[i][:l]`, `logit_1 = logits_1[i][:l]`) to `seq_len_i` so the intent is clear and avoids confusion with the digit "1".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/fast_topk_clusters.cu`:
- Around line 307-322: In the final refill pass (the t == 3 branch),
s_cached_logit_bits holds the full 32-bit payload but threshold_bin is an 8-bit
bucket value, so change the comparison to mask the low byte before comparing:
when reading int bin = s_cached_logit_bits[i] inside the loop, mask it (bin &
0xFF) or otherwise extract the lowest byte and compare that to threshold_bin;
keep the existing atomicAdd into shared_final_idx_count and the TopK-bound write
into s_topk_inds unchanged so the refill will correctly match and collect
candidates.
In `@csrc/flashinfer_mqa_histogram_binding.cu`:
- Around line 255-270: fast_topk_clusters currently casts num_clusters to int
and calls launch_fast_topk_clusters without validating it; add a validation step
in fast_topk_clusters to ensure num_clusters is one of the supported values
{1,2,4,8,16} before calling launch_fast_topk_clusters (use the num_clusters
variable and the launch_fast_topk_clusters call site to locate the change) and
if it is not supported, return/throw a clear error (or set indices to a safe
default) to avoid silently skipping the kernel and leaving indices
uninitialized; keep all other parameter checks (check_logits, check_indices,
check_seq_lens) intact.
In `@csrc/mqa_v2.cu`:
- Around line 141-142: The static_assert that checks register usage (involving
num_epilogue_threads, math_nreg, num_special_warps, and special_nreg) uses the
wrong constant 65636; change that literal to the correct GPU register limit
65536 so the assertion reads that combined registers <= 65536, keeping the same
message "too many registers".
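Why 65536 matters: the per-SM register file on these parts holds 64 K 32-bit registers, so the typo'd 65636 would let a configuration 100 registers over budget slip through. The formula shape below is inferred from the operand names in the assert and the values are illustrative:

```python
# Budget check with the corrected limit.
REGISTER_FILE = 65536  # 64 K registers per SM, not 65636

def fits_register_file(num_epilogue_threads, math_nreg,
                       num_special_warps, special_nreg):
    total = (num_epilogue_threads * math_nreg
             + num_special_warps * 32 * special_nreg)
    return total <= REGISTER_FILE

ok = fits_register_file(256, 128, 4, 40)       # 37888 registers: fits
too_big = fits_register_file(384, 160, 4, 64)  # 69632 registers: too many
```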
- Around line 479-486: The assert(err == CUDA_SUCCESS) after
cuTensorMapEncodeTiled is unsafe because it can compile out; replace it with
proper runtime error handling: check the returned err from
cuTensorMapEncodeTiled and call the existing helper (e.g., checkCu(err,
"cuTensorMapEncodeTiled") used in csrc/xqa/tensorMap.cpp) or throw a
std::runtime_error with a clear message including the cuError name/number and
context (e.g., which tmap/ptr operation failed); ensure any needed cleanup or
early return follows the error path so an invalid tensor-map descriptor is never
used.
---
Duplicate comments:
In `@benchmarks/bench_dsv3_sparse_indexer.py`:
- Around line 10-14: Argparse isn’t exposing the documented --compare-deepgemm
flag, so the script always uses compare_deepgemm = HAS_DEEP_GEMM; add a boolean
flag to the ArgumentParser (e.g., parser.add_argument("--compare-deepgemm",
action="store_true", help="Compare against DeepGEMM")) and then use the parsed
value (args.compare_deepgemm) instead of defaulting to HAS_DEEP_GEMM when
setting compare_deepgemm (and update the assignment near where compare_deepgemm
is set, referenced as compare_deepgemm and HAS_DEEP_GEMM).
In `@csrc/fast_topk_clusters.cu`:
- Around line 397-407: The code computes extern_shared_mem from num_cached but
always registers a fixed shared-memory footprint using setup_kernel_smem_once
for fast_topk_clusters_kernel<N>; guard against num_cached exceeding that
registered budget by computing the required dynamic shared memory
(extern_shared_mem) and ensuring you only call
setup_kernel_smem_once/setup_non_portable_clusters_once with a footprint >=
extern_shared_mem (or clamp num_cached to the max supported value before
computing extern_shared_mem), and if num_cached is larger emit a clear
error/return; update references around extern_shared_mem, num_cached,
setup_kernel_smem_once, setup_non_portable_clusters_once and
fast_topk_clusters_kernel so the registered smem matches the runtime requirement
or the code refuses unsupported sizes.
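One way to express the guard is a small predicate checked before launch; the budget and per-entry size below are illustrative assumptions, not the kernel's real numbers:

```cpp
#include <cstddef>

// Assumed footprint registered once via setup_kernel_smem_once (illustrative).
constexpr std::size_t kRegisteredSmemBytes = 200 * 1024;

// Assumed per-cached-entry cost; the real kernel's layout determines this.
std::size_t required_smem(std::size_t num_cached) {
  return num_cached * sizeof(float);
}

// Refuse (or clamp) num_cached values whose dynamic shared memory would
// exceed the footprint the kernel was registered with.
bool num_cached_supported(std::size_t num_cached) {
  return required_smem(num_cached) <= kRegisteredSmemBytes;
}
```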
In `@csrc/mqa_v2_hist.cu`:
- Around line 145-146: The static_assert that checks register usage uses the
incorrect constant 65636; update the assertion in the static_assert statement
that references num_epilogue_threads, math_nreg, num_special_warps, and
special_nreg to use the correct hardware limit 65536 instead of 65636 so the
check reads "... <= 65536" and preserves the existing error message "too many
registers".
- Around line 503-509: The call to cuTensorMapEncodeTiled stores its return in
err but uses assert(err == CUDA_SUCCESS), which can be compiled out; replace the
assert with explicit runtime error handling: check if err != CUDA_SUCCESS after
cuTensorMapEncodeTiled (the same pattern used in csrc/mqa_v2.cu), log or format
a clear error message including the numeric err and a short context string
(e.g., "cuTensorMapEncodeTiled failed"), and then abort/throw/return an error
code as appropriate for this module to prevent proceeding with an invalid
tensor-map descriptor; update the code around the err variable and the
cuTensorMapEncodeTiled call accordingly.
In `@flashinfer/mqa_histogram.py`:
- Around line 223-240: Add a backend guard to the public API by decorating
get_mqa_metadata with `@backend_requirement`; import the decorator and supply the
module-level support query functions (is_compute_capability_supported and
is_backend_supported) so callers fail fast on non-SM100a GPUs. Specifically, add
the `@backend_requirement(is_compute_capability_supported, is_backend_supported)`
decorator above the get_mqa_metadata function and ensure those two functions are
defined in, or imported into, the same module.
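The intended fail-fast behavior can be sketched with a stand-in decorator; the real `backend_requirement` lives in flashinfer and its exact signature may differ, so everything below is illustrative:

```python
import functools


def backend_requirement(*probes):
    """Toy stand-in: reject calls when any support probe returns False."""

    def deco(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for probe in probes:
                if not probe():
                    raise RuntimeError(f"{fn.__name__}: backend not supported")
            return fn(*args, **kwargs)

        return wrapper

    return deco


def is_compute_capability_supported():  # stand-in for the SM100a check
    return True


def is_backend_supported():
    return True


@backend_requirement(is_compute_capability_supported, is_backend_supported)
def get_mqa_metadata(seq_lens):
    return len(seq_lens)  # placeholder body
```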
- Around line 268-272: The logits buffer (logits) and related index buffer
(indices) are allocated using max_model_len but kernels write up to seq_lens[b],
so validate that max_model_len >= max(seq_lens) before allocation (or reallocate
using max_seq_len derived from seq_lens); update both the block that creates
logits/indices in mqa_histogram (the variables logits, indices, q,
max_model_len) and the analogous allocation in mqa_topk_indexer to compute
max_seq_len = seq_lens.max().item() (or raise a clear error) and use that value
for the second dimension to prevent out‑of‑bounds writes. Ensure the validation
uses the existing seq_lens tensor/variable so no silent truncation occurs.
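The requested bounds check is a one-liner before allocation; `seq_lens` is shown as a plain list here, while the real code would use `seq_lens.max().item()` on the tensor:

```python
def checked_buffer_len(max_model_len, seq_lens):
    """Validate that the allocated second dimension covers every sequence."""
    max_seq_len = max(seq_lens)
    if max_model_len < max_seq_len:
        raise ValueError(
            f"max_model_len ({max_model_len}) < max(seq_lens) ({max_seq_len}); "
            "kernels would write out of bounds"
        )
    return max_model_len  # safe second dimension for logits/indices
```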
In `@include/flashinfer/mqa_histogram/tcgen05_utils.cuh`:
- Around line 152-169: The static_assert inside the discarded else branch of the
template function tcgen05_ld_32x32b has a condition that does not depend on
WIDTH, so the compiler may evaluate it even though the branch is discarded and
fail the build for every instantiation; make the condition dependent on the
template parameter (e.g., a helper variable template always_false&lt;WIDTH&gt; that is
always false) and write static_assert(always_false&lt;WIDTH&gt;, "WIDTH must be 2, 4,
8, 16, 32, or 64") in the discarded branch, so only invalid WIDTH
instantiations trigger the assertion; ensure the helper template (always_false)
depends on WIDTH so the check stays SFINAE-friendly.
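The standard fix is to make the assertion's condition dependent on the template parameter so discarded branches are never evaluated; below is a toy sketch where `ld_width_tag()` stands in for `tcgen05_ld_32x32b` (whose real body issues tcgen05.ld instructions):

```cpp
// A variable template that is false for every WIDTH, but *dependent* on it.
template <int>
inline constexpr bool always_false = false;

template <int WIDTH>
constexpr int ld_width_tag() {
  if constexpr (WIDTH == 2 || WIDTH == 4 || WIDTH == 8 || WIDTH == 16 ||
                WIDTH == 32 || WIDTH == 64) {
    return WIDTH;  // stand-in for the real load path
  } else {
    // Dependent on WIDTH, so it only fires for invalid instantiations.
    static_assert(always_false<WIDTH>, "WIDTH must be 2, 4, 8, 16, 32, or 64");
    return 0;  // unreachable; keeps the compiler's return-path check happy
  }
}
```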
---
Nitpick comments:
In `@tests/attention/test_mqa_histogram.py`:
- Around line 94-96: In _dsa_topk_indexer change the unused unpacked variable
`num_heads` to a discarded name (e.g., `_num_heads` or `_`) so the tuple
unpacking reads `batch_size, _num_heads, head_dim = q_fp8.shape`, indicating the
value is intentionally unused and avoiding linter warnings; update only the
variable name in that assignment inside the _dsa_topk_indexer function.
- Around line 159-165: Rename the ambiguous loop variable `l` to a clearer name
like `seq_len_i` in the loop that iterates over range(B) where `prefix`,
`seq_lens`, `logits_ref`, `logits_0`, and `logits_1` are used; update its uses
(currently `l = int(seq_lens[i])` and slicing `logit_0 = logits_0[i][:l]`,
`logit_1 = logits_1[i][:l]`) to `seq_len_i` so the intent is clear and avoids
confusion with the digit "1".
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2f34ad89-71f6-4b38-9132-a193346b8b1b
📒 Files selected for processing (12)
- benchmarks/bench_dsv3_sparse_indexer.py
- benchmarks/bench_topk.py
- csrc/fast_topk_clusters.cu
- csrc/flashinfer_mqa_histogram_binding.cu
- csrc/mqa_metadata.cu
- csrc/mqa_v2.cu
- csrc/mqa_v2_hist.cu
- csrc/topk_clusters_pre_hist.cu
- flashinfer/mqa_histogram.py
- include/flashinfer/mqa_histogram/common.cuh
- include/flashinfer/mqa_histogram/tcgen05_utils.cuh
- tests/attention/test_mqa_histogram.py
🚧 Files skipped from review as they are similar to previous changes (4)
- csrc/mqa_metadata.cu
- csrc/topk_clusters_pre_hist.cu
- benchmarks/bench_topk.py
- include/flashinfer/mqa_histogram/common.cuh
[FAILED] Pipeline #46461374: 11/20 passed
Hi @kahyunnam, I benchmarked on a bare-metal B200 node for sm100a. Note that the topK works on both Hopper and Blackwell, but the gemm kernel uses tcgen05 instructions and so only works on Blackwell. Most of the speedup I observed comes from the topK; the histogram fusion and PDL add a slight boost on top. This PR is mostly to get the comparisons going; the code still needs a lot of cleanup before it can be merged.
This PR specifically targets the low-latency decode stage, where SMs are abundant relative to the problem size.
The key optimization for topK is the use of CTA clusters on top of a fused radix algorithm. For the sparse indexer, a further optimization fuses the computation of the first-bin histogram into the epilogue of the gemm kernel. This eliminates a full read pass over the logits, so only one row of logits is read in total; the rest is cached in shared memory for later stages. The fusion does not impact gemm performance, as it is moved off the critical path of the epilogue, while providing some speedup for the subsequent topK.
Comparison against flashinfer's topK
Note there is a minor difference between the two ops, as flashinfer's existing topK also writes the logit values in addition to the indices, but this part should be negligible.
For high batch sizes the algorithm would need adjustment, as increasing row-parallelism via clusters no longer makes sense. The SM-scheduling portion of the gemm kernel would also need to change: at low batch sizes every SM processes one q block, but at higher batch sizes it makes more sense to use an SM-persistent schedule that lets every SM process multiple q blocks, which may in turn introduce overhead at low batch sizes.
Comparison against deep gemm + flashinfer topK
This is the entire indexer, where the deepgemm_ref baseline uses DeepGEMM's `fp8_paged_mqa_logits` with flashinfer's topK.
Comparison against VLLM's indexer
Individual Operator comparison
Epilogue fusion does not affect gemm performance
TopK gets speedup