
Chore: Cute dsl moe update (TMA.RED implementation)#2529

Open
nv-yunzheq wants to merge 3 commits into flashinfer-ai:main from nv-yunzheq:cuteDSL_moe_update

Conversation

@nv-yunzheq
Collaborator

@nv-yunzheq nv-yunzheq commented Feb 9, 2026

📌 Description

This PR is a follow-up to PR #2398.
It integrates TRTLLM PR 10987, using TMA.RED to improve effective memory bandwidth.

Performance data (tested on GB200):

| Tokens | CuteDSL (main) ms | CuteDSL (TMA.RED) ms | TRTLLM gen ms | CUTLASS ms | Winner | CuteDSL Speedup (main/TMA.RED) |
|-------:|------------------:|---------------------:|--------------:|-----------:|--------|-------------------------------:|
| 1 | 0.064 | 0.064 | 0.053 | 0.099 | TRTLLM | 1.000x |
| 2 | 0.077 | 0.077 | 0.063 | 0.107 | TRTLLM | 1.000x |
| 4 | 0.096 | 0.096 | 0.085 | 0.131 | TRTLLM | 1.000x |
| 8 | 0.096 | 0.096 | 0.091 | 0.131 | TRTLLM | 1.000x |
| 16 | 0.101 | 0.102 | 0.103 | 0.138 | CuteDSL | 0.990x |
| 32 | 0.114 | 0.114 | 0.142 | 0.152 | CuteDSL | 1.000x |
| 62 | 0.122 | 0.122 | 0.183 | 0.163 | CuteDSL | 1.000x |
| 128 | 0.133 | 0.132 | 0.173 | 0.220 | CuteDSL | 1.008x |
| 256 | 0.142 | 0.138 | 0.220 | 0.251 | CuteDSL | 1.029x |
| 512 | 0.190 | 0.183 | 0.271 | 0.333 | CuteDSL | 1.038x |
| 1024 | 0.286 | 0.278 | 0.576 | 0.482 | CuteDSL | 1.029x |
| 2048 | 0.472 | 0.461 | 0.555 | 0.723 | CuteDSL | 1.024x |
| 4096 | 0.855 | 0.824 | 0.873 | 1.278 | CuteDSL | 1.038x |
| 8192 | 1.764 | 1.713 | 1.653 | 2.383 | TRTLLM | 1.030x |
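
For clarity, the last column is simply the ratio of the two CuteDSL timings; for example, for the 4096-token row:

# Speedup column = CuteDSL (main) latency / CuteDSL (TMA.RED) latency.
main_ms, tma_red_ms = 0.855, 0.824  # 4096-token row from the table above
print(f"{main_ms / tma_red_ms:.3f}x")  # -> 1.038x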

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced block-reduction optimization in MOE finalization kernels for improved performance on latest hardware.
    • Added support for block-wise reduction operations across multiple data types (BF16, FP32, FP16).
  • Performance

    • Optimized GPU memory utilization by reducing unnecessary cross-device data transfers during computation.

@coderabbitai
Contributor

coderabbitai bot commented Feb 9, 2026

📝 Walkthrough

Walkthrough

Introduces block-reduction optimization in Blackwell GEMM finalize kernel via new use_blkred flag enabling Tensor Memory block-reduction in epilogue. Adds DSL operations supporting BF16/FP16/FP32 reductions. Refactors compute reference test to maintain data on GPU, eliminating CPU transfers.

Changes

Cohort / File(s) Summary
Blackwell Block-Reduction Kernel
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Added use_blkred flag to enable block-reduction in epilogue. Introduced block-reduction atomic operations, modified shared memory layout generation with 16B alignment padding, updated stage accounting to reserve C smem bytes, added sC field to SharedStorage, implemented epilog_smem_copy_and_partition method, and extended kernel interface to accept c_smem_layout_staged and topK parameters.
Block-Reduction DSL Operations
flashinfer/fused_moe/cute_dsl/blackwell/utils.py
Added three new DSL user-op functions for block-wise reductions: blk_reduce_bf16, blk_reduce_fp32, blk_reduce_fp16, each using corresponding inline assembly for atomic accumulation across data types.
Frontend Configuration
flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Updated kernel instantiation in _get_compiled_finalize_kernel to pass use_blkred=True parameter when constructing Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel.
Compute Reference Optimization
tests/moe/test_cute_dsl_fused_moe.py
Refactored compute_reference_moe_fp4 to maintain tensor data on GPU throughout, eliminating cross-device transfers to CPU for intermediate computations and returning device-resident output directly.
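
A minimal sketch of the GPU-resident reference pattern described in the last row above; the tensor and function names here are illustrative, not the actual test code.

import torch

def reference_gemm(hidden_states: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Previously the reference path round-tripped through the CPU
    # (e.g. hidden_states.cpu() ... result.cuda()); the refactor keeps the
    # computation on whatever device the inputs already live on.
    return hidden_states.float() @ weight.float().t()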

Sequence Diagram

sequenceDiagram
    participant Kernel as Kernel Launch
    participant Epilogue as Epilogue Handler
    participant RegMem as Register Memory
    participant SmemCopy as Smem Copy/Partition
    participant BlockRed as Block-Reduce Op
    participant GemmMem as GEMM Memory

    Kernel->>Epilogue: Invoke epilogue with tAcc data
    alt use_blkred enabled
        Epilogue->>RegMem: Load tTR_rC from registers
        Epilogue->>SmemCopy: Call epilog_smem_copy_and_partition
        SmemCopy->>SmemCopy: Partition R2S tiled copy
        SmemCopy-->>Epilogue: Return partitioned copy & tensors
        Epilogue->>SmemCopy: Store tTR_rC to sC (shared memory)
        Epilogue->>Epilogue: Fence & synchronize
        SmemCopy->>BlockRed: Prepare block-reduce input from sC
        BlockRed->>GemmMem: Execute reduce.async.bulk (bf16/fp32/fp16)
        GemmMem-->>BlockRed: Accumulate to destination
    else use_blkred disabled
        Epilogue->>Epilogue: Execute per-subtile epilogue path (original)
    end
    Epilogue-->>Kernel: Return with finalized output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

v0.6.3

Suggested reviewers

  • djmmoss
  • yzh119
  • cyx-6
  • jiahanc

Poem

🐰 A block reduction hops with glee,
Through Blackwell's swift epilogue spree,
BF16 and FP32 aligned in sync,
GPU tensors never CPU wink,
Optimization's dance—thump thump—with flair! 🎭

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 73.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Chore: Cute dsl moe update (TMA.RED implementation)' accurately describes the main change: implementation of TMA.RED optimization in the Cute DSL MOE kernel code.
Description check ✅ Passed The description includes reference to related PR #2398, details about the TRTLLM integration, comprehensive performance data, but the pre-commit and tests checklist items remain unchecked.


@gemini-code-assist
Contributor

Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of the Cute DSL MoE kernel on Blackwell architecture by implementing TMA.RED. The core change involves introducing a block reduction mechanism that optimizes memory bandwidth during the epilogue phase of GEMM operations. This leads to measurable speedups, especially for larger token sizes, and refines the shared memory management within the kernel. Additionally, the testing infrastructure has been improved by shifting reference computations to the GPU.

Highlights

  • TMA.RED Integration: Integrated TMA.RED (Tensor Memory Accelerator Reduce) functionality into the Cute DSL MoE kernel for Blackwell architecture, aiming to improve effective memory bandwidth.
  • Performance Optimization: Introduced a use_blkred flag to enable a new block reduction strategy, which shows performance improvements for larger token counts (e.g., 1.038x speedup for 4096 tokens) on GB200.
  • Shared Memory Epilogue: Modified the GEMM kernel's epilogue stage to leverage shared memory for output accumulation when block reduction is enabled, optimizing data movement.
  • GPU-based Reference Computation: Updated the reference MoE computation in tests to execute directly on the GPU, removing CPU-specific tensor transfers for consistency and efficiency.


Changelog
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
    • Imported new block reduction utility functions (blk_reduce_bf16, blk_reduce_fp16, blk_reduce_fp32).
    • Added use_blkred parameter to the kernel configuration and constructor.
    • Modified shared memory layout for the output (C) tensor to accommodate block reduction.
    • Updated kernel arguments to include c_smem_layout_staged and topK.
    • Implemented conditional logic within the kernel to utilize block reduction for epilogue accumulation when use_blkred is true.
    • Adjusted the _compute_stages function to account for shared memory usage by the block reduction mechanism and changed num_c_stage default from 2 to 1.
    • Added epilog_smem_copy_and_partition method to create tiled copies for register to shared memory (R2S).
  • flashinfer/fused_moe/cute_dsl/blackwell/utils.py
    • Added blk_reduce_bf16, blk_reduce_fp32, and blk_reduce_fp16 utility functions using LLVM inline assembly for asynchronous bulk global to shared memory reduction.
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
    • Enabled the use_blkred flag by setting it to True when compiling the fused GEMM kernel.
  • tests/moe/test_cute_dsl_fused_moe.py
    • Updated compute_reference_moe_fp4 to perform reference computations directly on the GPU, removing previous CPU-based operations and tensor transfers.
    • Removed CPU-specific tensor conversions for hidden_states, gemm1_weights, gemm2_weights, token_selected_experts, and token_final_scales.

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !303 has been created, and the CI pipeline #43643225 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces TMA.RED (Tensor Memory Access with Reduction) to improve performance in the CuteDSL MoE kernel for Blackwell GPUs. The changes are quite extensive, adding a new code path for block reduction in the epilogue, modifying shared memory layouts, and updating staging logic. The tests have also been improved to run the reference implementation on the GPU, which is a welcome change. I've identified a couple of potential issues, one of which is critical, that should be addressed to ensure correctness.

Comment on lines +347 to +361

@dsl_user_op
def blk_reduce_fp16(dst_gemm, src_smem, size, loc=None, ip=None):
    llvm.inline_asm(
        None,
        [
            dst_gemm.iterator.llvm_ptr,
            src_smem.iterator.llvm_ptr,
            size.ir_value(),
        ],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
        "l,l,r",
        has_side_effects=True,
        loc=loc,
        ip=ip,
    )
Contributor

critical

The inline assembly for blk_reduce_fp16 appears to be missing the .add operation specifier. The instruction cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 is likely incorrect. Based on the patterns for bf16 and fp32 and PTX documentation, it should be cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16. Without the .add, the reduction operation is not specified and will likely lead to incorrect results for fp16 data types.

Suggested change
 @dsl_user_op
 def blk_reduce_fp16(dst_gemm, src_smem, size, loc=None, ip=None):
     llvm.inline_asm(
         None,
         [
             dst_gemm.iterator.llvm_ptr,
             src_smem.iterator.llvm_ptr,
             size.ir_value(),
         ],
-        "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
+        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16 [$0], [$1], $2;",
         "l,l,r",
         has_side_effects=True,
         loc=loc,
         ip=ip,
     )

Comment on lines +592 to 600

swizzled_pad = 16 // (self.out_dtype.width // 8)
self.c_smem_layout_staged = cute.make_layout(
    (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], self.num_c_stage),
    stride=(
        self.cta_tile_shape_mnk[1] + swizzled_pad,
        1,
        self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + 8),
    ),
)
Contributor

high

The stride calculation for the stage dimension of c_smem_layout_staged appears to use a hardcoded padding of 8, while the padding for the M dimension correctly uses swizzled_pad. This can lead to incorrect memory access for data types where swizzled_pad is not 8 (e.g., float32 where it would be 4). For consistency and correctness, swizzled_pad should be used in both places.

Suggested change
 swizzled_pad = 16 // (self.out_dtype.width // 8)
 self.c_smem_layout_staged = cute.make_layout(
     (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], self.num_c_stage),
     stride=(
         self.cta_tile_shape_mnk[1] + swizzled_pad,
         1,
-        self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + 8),
+        self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + swizzled_pad),
     ),
 )
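
For reference, the padding values both reviewers rely on follow directly from the formula in the snippet above; a quick illustrative check (not part of the suggested diff):

# swizzled_pad = 16 // (out_dtype.width // 8):
# 16-bit outputs (BF16/FP16) -> 8 elements of padding, 32-bit outputs (FP32) -> 4 elements.
for width_bits in (16, 32):
    print(width_bits, 16 // (width_bits // 8))  # prints "16 8" then "32 4"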

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 592-600: The stage stride uses a hardcoded 8 instead of the
computed swizzled_pad leading to inconsistent layout sizes; update the third
element of the stride tuple in c_smem_layout_staged (created in make_layout) to
use (self.cta_tile_shape_mnk[1] + swizzled_pad) instead of
(self.cta_tile_shape_mnk[1] + 8) so the stage stride calculation is consistent
with swizzled_pad (referencing swizzled_pad, c_smem_layout_staged,
cta_tile_shape_mnk, num_c_stage).
- Around line 1971-2018: The copy into shared tensor tRS_sC uses tRS_sC[(None,
None, real_subtile_idx, None)] but sC was allocated with a size-1 third mode
(after partition_D) so real_subtile_idx can go out of bounds; fix by making the
shared buffer's third dimension match subtile_cnt or by removing the subtile
index and storing subtiles sequentially. Concretely, either change sC allocation
(the tensor named sC used before partition_D) to have shape (..., subtile_cnt,
...) so partition_D yields tRS_sC with a third mode of length subtile_cnt, or
replace the indexed copy in the block under use_blkred (the tiled_copy_r2s ->
tRS_sC copy that uses real_subtile_idx) with code that writes each subtile
sequentially (store into tRS_sC without indexing by real_subtile_idx, e.g.,
advance the destination slice per subtile) so accesses never exceed the existing
size-1 third dimension.

In `@flashinfer/fused_moe/cute_dsl/blackwell/utils.py`:
- Around line 347-361: The PTX inline-asm string in blk_reduce_fp16 is missing
the ".add." reduction opcode; update the assembly template in blk_reduce_fp16
(the llvm.inline_asm call) to include the add specifier so it matches the other
reducers (e.g., use
"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16" or the same
pattern used by blk_reduce_bf16/blk_reduce_fp32), leaving the argument list,
constraint string, and call site (dst_gemm.iterator.llvm_ptr,
src_smem.iterator.llvm_ptr, size.ir_value(), loc, ip) unchanged.
🧹 Nitpick comments (2)
flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

206-214: use_blkred=True is hardcoded but not included in the cache key.

Since use_blkred is always True here, this is currently fine. However, if this is ever made configurable (e.g., exposed as a parameter to the public API or toggled based on heuristics), the cache on line 204 would serve stale kernels compiled with a different use_blkred value.

Consider either adding use_blkred to the cache key now, or leaving a comment noting the constraint.
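
A minimal sketch of the first option, with assumed names (the actual cache key fields in _get_compiled_finalize_kernel may differ):

# Illustrative only: fold the flag into whatever tuple keys the kernel cache,
# so a future configurable use_blkred can never return a stale compiled kernel.
def make_cache_key(tile_shape, out_dtype, top_k, use_blkred):
    return (tuple(tile_shape), str(out_dtype), int(top_k), bool(use_blkred))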

tests/moe/test_cute_dsl_fused_moe.py (1)

38-57: Use flashinfer.utils.get_compute_capability instead of torch.cuda.get_device_properties.

The coding guidelines require test files under tests/ to use flashinfer.utils functions for GPU architecture checks. The is_sm100_family() function reimplements this check using raw PyTorch APIs. The same get_compute_capability utility is already used in production code (e.g., blockscaled_contiguous_grouped_gemm_finalize_fusion.py line 368).

♻️ Proposed fix
+from flashinfer.utils import get_compute_capability
+
+
 def is_sm100_family():
     """Check for SM100 family (Blackwell: SM100, SM103, SM110)."""
     if not torch.cuda.is_available():
         return False
-    props = torch.cuda.get_device_properties(0)
-    return props.major == 10
+    major, _ = get_compute_capability(torch.device("cuda:0"))
+    return major == 10

As per coding guidelines, tests/**/*.py: "Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures".
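
If the helper is reworked as proposed, a typical (hypothetical) consumer in the test module would look like:

import pytest

# Test name is hypothetical; the point is the architecture-gated skip pattern.
@pytest.mark.skipif(not is_sm100_family(), reason="requires an SM100-family (Blackwell) GPU")
def test_cute_dsl_fused_moe_tma_red():
    ...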

Comment on lines +592 to 600

swizzled_pad = 16 // (self.out_dtype.width // 8)
self.c_smem_layout_staged = cute.make_layout(
    (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], self.num_c_stage),
    stride=(
        self.cta_tile_shape_mnk[1] + swizzled_pad,
        1,
        self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + 8),
    ),
)
Contributor

⚠️ Potential issue | 🟡 Minor

Hardcoded 8 in stage stride is inconsistent with swizzled_pad.

Line 596 computes the row stride as cta_tile_shape_mnk[1] + swizzled_pad, but line 598 uses cta_tile_shape_mnk[1] + 8 for the stage stride. For Float32 output (swizzled_pad = 4), the stage stride would be larger than actually needed, though currently harmless since num_c_stage = 1.

🔧 Suggested fix for consistency
         stride=(
             self.cta_tile_shape_mnk[1] + swizzled_pad,
             1,
-            self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + 8),
+            self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + swizzled_pad),
         ),

Comment on lines +1971 to +2018

if cutlass.const_expr(self.use_blkred):
    tRS_rC.store(acc_vec_final.to(self.out_dtype))
    if is_valid_row:
        cute.copy(
            tiled_copy_r2s,
            tRS_rC,
            tRS_sC[(None, None, real_subtile_idx, None)],
        )
else:
    tTR_rC.store(acc_vec_final.to(self.out_dtype))
    if is_valid_row:
        rOut_epi = cute.make_tensor(tTR_rC.iterator, epi_layout)

        base_coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[
            1
        ] + real_subtile_idx * cute.size(tTR_rC)

        scatter_out = cute.domain_offset(
            (token_idx, 0, 0),
            out,  # Use original tensor to get real pointer
        )

        for index in cutlass.range(
            self.epi_loop_size, unroll_full=True
        ):
            coord_n = base_coord_n + index * self.element_offset
            scatter_out_offset = cute.domain_offset(
                (0, coord_n, 0), scatter_out
            )
            if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
                rOut_epi_packed = rOut_epi[index, None, None]
                vectorized_atomic_add_bf16x8(
                    rOut_epi_packed, scatter_out_offset
                )
            elif cutlass.const_expr(self.out_dtype == cutlass.Float32):
                rOut_epi_packed = rOut_epi[index, None]
                vectorized_atomic_add_fp32x2(
                    rOut_epi_packed, scatter_out_offset
                )
            else:
                rOut_epi_packed = rOut_epi[index]
                atomic_add_func(rOut_epi_packed, scatter_out_offset)
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the file and understand the context around the problematic lines
wc -l flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 167


🏁 Script executed:

# Check how c_smem_layout_staged is constructed
rg -n "c_smem_layout_staged" --type=py -C 3 flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 2161


🏁 Script executed:

# Look at subtile_cnt computation
rg -n "subtile_cnt" --type=py -C 3 flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 1035


🏁 Script executed:

# Check epi_tile dimensions
rg -n "epi_tile" --type=py -C 2 flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 2657


🏁 Script executed:

# Look at the partition_D call and understand how sC is structured
rg -n "partition_D" --type=py -C 5 flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 1391


🏁 Script executed:

# Look at the sC tensor construction and staging information
rg -n "sC.*=" --type=py flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 275


🏁 Script executed:

# Look at tTR_tAcc construction and its shape/modes
rg -n "tTR_tAcc" --type=py flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py -B 5 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 3865


🏁 Script executed:

# Look at the accumulator tensor memory copy partition to understand mode 3
rg -n "epilog_tmem_copy_and_partition" --type=py flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py -A 30

Repository: flashinfer-ai/flashinfer

Length of output: 3088


🏁 Script executed:

# Check num_c_stage value
rg -n "num_c_stage" --type=py flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 384


🏁 Script executed:

# Look at the actual epilog_smem_copy_and_partition function signature and how tRS_sC is used
sed -n '2155,2175p' flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 772


🏁 Script executed:

# Check the full context of where tRS_sC indexing happens - lines around 1972
sed -n '1960,1980p' flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 1143


🏁 Script executed:

# Look at what accumulator tensor layout is used in epilog_tmem_copy_and_partition
sed -n '2089,2150p' flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 2506


Block-reduction epilogue path: fix subtile indexing into tRS_sC.

The shared memory tensor sC is created with layout shape (cta_tile_M, cta_tile_N, num_c_stage=1) (line 593-594). After partition_D, tRS_sC preserves this 3-mode structure with the third dimension remaining size 1. However, the copy operation at line 1975 indexes this dimension as tRS_sC[(None, None, real_subtile_idx, None)] where real_subtile_idx ranges from 0 to subtile_cnt - 1. Since subtile_cnt (line 1936) is derived from the grouped accumulator dimensions and can exceed 1, indexing mode 2 with real_subtile_idx > 0 will access out of bounds.

Either allocate sC with a third dimension matching subtile_cnt, or store each subtile sequentially instead of indexing by subtile index.


Comment on lines +347 to +361

@dsl_user_op
def blk_reduce_fp16(dst_gemm, src_smem, size, loc=None, ip=None):
    llvm.inline_asm(
        None,
        [
            dst_gemm.iterator.llvm_ptr,
            src_smem.iterator.llvm_ptr,
            size.ir_value(),
        ],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
        "l,l,r",
        has_side_effects=True,
        loc=loc,
        ip=ip,
    )
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Find the blk_reduce_bf16 function to compare
cat flashinfer/fused_moe/cute_dsl/blackwell/utils.py | head -n 400 | tail -n 100

Repository: flashinfer-ai/flashinfer

Length of output: 2857


🏁 Script executed:

# Search for blk_reduce_bf16 in the file to see the comparison
grep -n "blk_reduce_bf16" flashinfer/fused_moe/cute_dsl/blackwell/utils.py

Repository: flashinfer-ai/flashinfer

Length of output: 137


🏁 Script executed:

# Get more context around the blk_reduce_bf16 function
sed -n '330,380p' flashinfer/fused_moe/cute_dsl/blackwell/utils.py

Repository: flashinfer-ai/flashinfer

Length of output: 1479


🌐 Web query:

NVIDIA PTX ISA cp.reduce.async.bulk instruction syntax documentation

💡 Result:

The official NVIDIA documentation for cp.reduce.async.bulk is in the PTX ISA manual, section “9.7.9.25.4.2. Data Movement and Conversion Instructions: cp.reduce.async.bulk”. It defines these syntax forms (shown here verbatim in structure, with line breaks added): [1]

// shared::cta -> shared::cluster (mbarrier completion)
cp.reduce.async.bulk.dst.src.completion_mechanism.redOp.type
              [dstMem], [srcMem], size, [mbar]

.dst =                  { .shared::cluster }
.src =                  { .shared::cta }
.completion_mechanism = { .mbarrier::complete_tx::bytes }
.redOp =                { .and, .or, .xor, .add, .inc, .dec, .min, .max }
.type =                 { .b32, .u32, .s32, .b64, .u64 }


// shared::cta -> global (bulk_group completion; optional L2 cache hint + cache-policy)
cp.reduce.async.bulk.dst.src.completion_mechanism{.level::cache_hint}.redOp.type
              [dstMem], [srcMem], size{, cache-policy}

.dst =                  { .global }
.src =                  { .shared::cta }
.completion_mechanism = { .bulk_group }
.level::cache_hint    = { .L2::cache_hint }
.type =                 { .f16, .bf16, .b32, .u32, .s32, .b64, .u64, .s64, .f32, .f64 }


// f16/bf16 add variant requiring .noftz
cp.reduce.async.bulk.dst.src.completion_mechanism{.level::cache_hint}.add.noftz.type
              [dstMem], [srcMem], size{, cache-policy}

.type =                 { .f16, .bf16 }

Notes also covered there: size must be a multiple of 16 and both addresses must be 16-byte aligned; instruction introduced in PTX ISA 8.0 and requires sm_90+. [1]

If you also want a C++-level reference, libcu++/CCCL lists and documents cp.reduce.async.bulk under its PTX instruction wrappers (CUDA 12.5 / CCCL 2.4.0). [2]

Sources: PTX ISA manual (Parallel Thread Execution) [1]; CCCL/libcu++ PTX instruction reference [2].


Add missing .add. reduction operation in blk_reduce_fp16 PTX instruction.

The inline assembly string "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16" is missing the .add. operation specifier. According to the PTX ISA specification, the f16/bf16 add variant requires the syntax .add.noftz.type. Both blk_reduce_bf16 and blk_reduce_fp32 in the same file correctly include the .add. operation, but blk_reduce_fp16 omits it. This will cause a PTX assembly error at kernel compilation time.

Fix
-        "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
+        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16 [$0], [$1], $2;",
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 @dsl_user_op
 def blk_reduce_fp16(dst_gemm, src_smem, size, loc=None, ip=None):
     llvm.inline_asm(
         None,
         [
             dst_gemm.iterator.llvm_ptr,
             src_smem.iterator.llvm_ptr,
             size.ir_value(),
         ],
-        "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
+        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16 [$0], [$1], $2;",
         "l,l,r",
         has_side_effects=True,
         loc=loc,
         ip=ip,
     )

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #43643225: canceled

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !303 has been created, and the CI pipeline #43648436 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43648436: 14/20 passed

Collaborator

@aleozlx aleozlx left a comment

lgtm

@aleozlx
Collaborator

aleozlx commented Feb 10, 2026

tests seem clean. ok to merge

@yongwww
Member

yongwww commented Feb 11, 2026

@nv-yunzheq we can rebase the PR onto the latest main to kick off public CI

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !303 has been updated with latest changes, and the CI pipeline #43802801 is currently running. I'll report back once the pipeline job completes.

@nv-yunzheq
Collaborator Author

@flashinfer-bot run

@nv-yunzheq
Collaborator Author

@nv-yunzheq we can rebase the PR onto the latest main to kick off public CI

@yongwww What is the command to kick off public CI?

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #43802801: canceled

@yongwww yongwww added the run-ci label Feb 11, 2026
@yongwww
Member

yongwww commented Feb 11, 2026

@nv-yunzheq we can either add a label run-ci or leave a comment @flashinfer-bot run to trigger the public ci. Will add them in a doc.
