Chore: Cute DSL MoE update (TMA.RED implementation) #2529
nv-yunzheq wants to merge 3 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Introduces a block-reduction optimization in the Blackwell GEMM finalize kernel via a new `use_blkred` epilogue path.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Kernel as Kernel Launch
    participant Epilogue as Epilogue Handler
    participant RegMem as Register Memory
    participant SmemCopy as Smem Copy/Partition
    participant BlockRed as Block-Reduce Op
    participant GemmMem as GEMM Memory
    Kernel->>Epilogue: Invoke epilogue with tAcc data
    alt use_blkred enabled
        Epilogue->>RegMem: Load tTR_rC from registers
        Epilogue->>SmemCopy: Call epilog_smem_copy_and_partition
        SmemCopy->>SmemCopy: Partition R2S tiled copy
        SmemCopy-->>Epilogue: Return partitioned copy & tensors
        Epilogue->>SmemCopy: Store tTR_rC to sC (shared memory)
        Epilogue->>Epilogue: Fence & synchronize
        SmemCopy->>BlockRed: Prepare block-reduce input from sC
        BlockRed->>GemmMem: Execute reduce.async.bulk (bf16/fp32/fp16)
        GemmMem-->>BlockRed: Accumulate to destination
    else use_blkred disabled
        Epilogue->>Epilogue: Execute per-subtile epilogue path (original)
    end
    Epilogue-->>Kernel: Return with finalized output
```
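In rough pseudocode, the dispatch in the diagram looks like the sketch below. All names are placeholders that mirror the diagram, not the actual CuteDSL kernel code from this PR.

```python
# Sketch only: epilogue dispatch as described by the sequence diagram above.
def epilogue(t_acc, use_blkred, *, r2s_copy, smem_buf, blk_reduce, dst_gmem,
             per_subtile_epilogue, fence_and_sync):
    if use_blkred:
        # TMA.RED path: stage accumulators in shared memory, then issue a
        # bulk reduce from shared memory into global memory.
        r2s_copy(t_acc, smem_buf)        # registers -> shared memory (R2S tiled copy)
        fence_and_sync()                 # make the smem writes visible
        blk_reduce(dst_gmem, smem_buf)   # cp.reduce.async.bulk-style accumulation
    else:
        # Original path: per-subtile epilogue with atomic adds to global memory.
        per_subtile_epilogue(t_acc, dst_gmem)
```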
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: 2 passed, 1 failed (warning)
Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly improves the performance of the Cute DSL MoE kernel on the Blackwell architecture by implementing TMA.RED. The core change introduces a block-reduction mechanism that improves effective memory bandwidth during the epilogue phase of the GEMM operations, yielding measurable speedups, especially for larger token counts, and it refines the kernel's shared memory management. The testing infrastructure has also been improved by moving reference computations to the GPU.

Highlights
/bot run
Code Review
This pull request introduces TMA.RED (Tensor Memory Access with Reduction) to improve performance in the CuteDSL MoE kernel for Blackwell GPUs. The changes are quite extensive, adding a new code path for block reduction in the epilogue, modifying shared memory layouts, and updating staging logic. The tests have also been improved to run the reference implementation on the GPU, which is a welcome change. I've identified a couple of potential issues, one of which is critical, that should be addressed to ensure correctness.
flashinfer/fused_moe/cute_dsl/blackwell/utils.py

```python
@dsl_user_op
def blk_reduce_fp16(dst_gemm, src_smem, size, loc=None, ip=None):
    llvm.inline_asm(
        None,
        [
            dst_gemm.iterator.llvm_ptr,
            src_smem.iterator.llvm_ptr,
            size.ir_value(),
        ],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
        "l,l,r",
        has_side_effects=True,
        loc=loc,
        ip=ip,
    )
```
The inline assembly for blk_reduce_fp16 appears to be missing the .add operation specifier. The instruction cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 is likely incorrect. Based on the patterns for bf16 and fp32 and PTX documentation, it should be cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16. Without the .add, the reduction operation is not specified and will likely lead to incorrect results for fp16 data types.
Suggested change:

```diff
-        "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
+        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16 [$0], [$1], $2;",
```
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (lines 592-600)

```python
swizzled_pad = 16 // (self.out_dtype.width // 8)
self.c_smem_layout_staged = cute.make_layout(
    (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], self.num_c_stage),
    stride=(
        self.cta_tile_shape_mnk[1] + swizzled_pad,
        1,
        self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + 8),
    ),
)
```
The stride calculation for the stage dimension of c_smem_layout_staged appears to use a hardcoded padding of 8, while the padding for the M dimension correctly uses swizzled_pad. This can lead to incorrect memory access for data types where swizzled_pad is not 8 (e.g., float32 where it would be 4). For consistency and correctness, swizzled_pad should be used in both places.
Suggested change:

```diff
-        self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + 8),
+        self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + swizzled_pad),
```
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 592-600: The stage stride uses a hardcoded 8 instead of the
computed swizzled_pad leading to inconsistent layout sizes; update the third
element of the stride tuple in c_smem_layout_staged (created in make_layout) to
use (self.cta_tile_shape_mnk[1] + swizzled_pad) instead of
(self.cta_tile_shape_mnk[1] + 8) so the stage stride calculation is consistent
with swizzled_pad (referencing swizzled_pad, c_smem_layout_staged,
cta_tile_shape_mnk, num_c_stage).
- Around line 1971-2018: The copy into shared tensor tRS_sC uses tRS_sC[(None,
None, real_subtile_idx, None)] but sC was allocated with a size-1 third mode
(after partition_D) so real_subtile_idx can go out of bounds; fix by making the
shared buffer's third dimension match subtile_cnt or by removing the subtile
index and storing subtiles sequentially. Concretely, either change sC allocation
(the tensor named sC used before partition_D) to have shape (..., subtile_cnt,
...) so partition_D yields tRS_sC with a third mode of length subtile_cnt, or
replace the indexed copy in the block under use_blkred (the tiled_copy_r2s ->
tRS_sC copy that uses real_subtile_idx) with code that writes each subtile
sequentially (store into tRS_sC without indexing by real_subtile_idx, e.g.,
advance the destination slice per subtile) so accesses never exceed the existing
size-1 third dimension.
In `flashinfer/fused_moe/cute_dsl/blackwell/utils.py`:
- Around line 347-361: The PTX inline-asm string in blk_reduce_fp16 is missing
the ".add." reduction opcode; update the assembly template in blk_reduce_fp16
(the llvm.inline_asm call) to include the add specifier so it matches the other
reducers (e.g., use
"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16" or the same
pattern used by blk_reduce_bf16/blk_reduce_fp32), leaving the argument list,
constraint string, and call site (dst_gemm.iterator.llvm_ptr,
src_smem.iterator.llvm_ptr, size.ir_value(), loc, ip) unchanged.
🧹 Nitpick comments (2)

flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

206-214: `use_blkred=True` is hardcoded but not included in the cache key. Since `use_blkred` is always `True` here, this is currently fine. However, if it is ever made configurable (e.g., exposed as a parameter of the public API or toggled based on heuristics), the cache on line 204 would serve stale kernels compiled with a different `use_blkred` value. Consider either adding `use_blkred` to the cache key now, or leaving a comment noting the constraint. A minimal sketch of the cache-key option follows this list.

tests/moe/test_cute_dsl_fused_moe.py (1)

38-57: Use `flashinfer.utils.get_compute_capability` instead of `torch.cuda.get_device_properties`. The coding guidelines require test files under `tests/` to use `flashinfer.utils` functions for GPU architecture checks. The `is_sm100_family()` function reimplements this check using raw PyTorch APIs. The same `get_compute_capability` utility is already used in production code (e.g., `blockscaled_contiguous_grouped_gemm_finalize_fusion.py`, line 368).

♻️ Proposed fix

```diff
+from flashinfer.utils import get_compute_capability
+
+
 def is_sm100_family():
     """Check for SM100 family (Blackwell: SM100, SM103, SM110)."""
     if not torch.cuda.is_available():
         return False
-    props = torch.cuda.get_device_properties(0)
-    return props.major == 10
+    major, _ = get_compute_capability(torch.device("cuda:0"))
+    return major == 10
```

As per coding guidelines, `tests/**/*.py`: "Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures".
(CodeRabbit also flagged the c_smem_layout_staged construction quoted above, lines 592-600:)
Hardcoded 8 in stage stride is inconsistent with swizzled_pad.

Line 596 computes the row stride as `cta_tile_shape_mnk[1] + swizzled_pad`, but line 598 uses `cta_tile_shape_mnk[1] + 8` for the stage stride. For Float32 output (`swizzled_pad = 4`), the stage stride would be larger than actually needed, though currently harmless since `num_c_stage = 1`.

🔧 Suggested fix for consistency

```diff
 stride=(
     self.cta_tile_shape_mnk[1] + swizzled_pad,
     1,
-    self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + 8),
+    self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + swizzled_pad),
 ),
```
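For reference, a quick standalone check of how the padding varies with output dtype, using the `swizzled_pad = 16 // (width // 8)` formula from the quoted code (the dtype widths below are the usual 16/32-bit values, not taken from this PR):

```python
# Illustration of the padding formula used above:
# swizzled_pad = 16 bytes / element size in bytes.
for name, width_bits in [("bf16", 16), ("fp16", 16), ("fp32", 32)]:
    swizzled_pad = 16 // (width_bits // 8)
    print(name, swizzled_pad)
# bf16 -> 8, fp16 -> 8, fp32 -> 4, which is why the hardcoded 8 only
# matches the 16-bit output types.
```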
Quoted code under review (blockscaled_contiguous_grouped_gemm_finalize_fusion.py, around lines 1971-2018):

```python
# (also quoted in the diff, exact placement unclear from this excerpt)
# scatter_out = cute.domain_offset(
#     (token_idx, 0, 0),
#     out,  # Use original tensor to get real pointer
# )

if cutlass.const_expr(self.use_blkred):
    tRS_rC.store(acc_vec_final.to(self.out_dtype))
    if is_valid_row:
        cute.copy(
            tiled_copy_r2s,
            tRS_rC,
            tRS_sC[(None, None, real_subtile_idx, None)],
        )
else:
    tTR_rC.store(acc_vec_final.to(self.out_dtype))
    if is_valid_row:
        rOut_epi = cute.make_tensor(tTR_rC.iterator, epi_layout)

        base_coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[
            1
        ] + real_subtile_idx * cute.size(tTR_rC)

        for index in cutlass.range(self.epi_loop_size, unroll_full=True):
            coord_n = base_coord_n + index * self.element_offset
            scatter_out_offset = cute.domain_offset(
                (0, coord_n, 0), scatter_out
            )
            if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
                rOut_epi_packed = rOut_epi[index, None, None]
                vectorized_atomic_add_bf16x8(
                    rOut_epi_packed, scatter_out_offset
                )
            elif cutlass.const_expr(self.out_dtype == cutlass.Float32):
                rOut_epi_packed = rOut_epi[index, None]
                vectorized_atomic_add_fp32x2(
                    rOut_epi_packed, scatter_out_offset
                )
            else:
                rOut_epi_packed = rOut_epi[index]
                atomic_add_func(rOut_epi_packed, scatter_out_offset)
```
🧩 Analysis chain

Verification scripts inspected the construction of c_smem_layout_staged, the subtile_cnt computation, the epi_tile dimensions, the partition_D call and the resulting sC / tRS_sC structure, num_c_stage, and the epilog_smem_copy_and_partition / epilog_tmem_copy_and_partition helpers in blockscaled_contiguous_grouped_gemm_finalize_fusion.py.
Block-reduction epilogue path: fix subtile indexing into tRS_sC.
The shared memory tensor sC is created with layout shape (cta_tile_M, cta_tile_N, num_c_stage=1) (line 593-594). After partition_D, tRS_sC preserves this 3-mode structure with the third dimension remaining size 1. However, the copy operation at line 1975 indexes this dimension as tRS_sC[(None, None, real_subtile_idx, None)] where real_subtile_idx ranges from 0 to subtile_cnt - 1. Since subtile_cnt (line 1936) is derived from the grouped accumulator dimensions and can exceed 1, indexing mode 2 with real_subtile_idx > 0 will access out of bounds.
Either allocate sC with a third dimension matching subtile_cnt, or store each subtile sequentially instead of indexing by subtile index.
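One possible shape for the first option is sketched below. This is a hedged fragment in the context of the layout construction quoted earlier (lines 592-600); it assumes an upper bound on the subtile count (`max_subtile_cnt`, a hypothetical name, not from the PR) is known when the staged layout is built, which may not hold in the actual code, and it increases shared memory usage proportionally, so the sequential-store alternative may be preferable when smem is tight.

```python
# Hypothetical sketch: give the shared C buffer one slot per subtile so that
# tRS_sC[(None, None, real_subtile_idx, None)] indexes a real mode.
self.c_smem_layout_staged = cute.make_layout(
    (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], max_subtile_cnt),
    stride=(
        self.cta_tile_shape_mnk[1] + swizzled_pad,
        1,
        self.cta_tile_shape_mnk[0] * (self.cta_tile_shape_mnk[1] + swizzled_pad),
    ),
)
```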
(The blk_reduce_fp16 helper quoted above, flashinfer/fused_moe/cute_dsl/blackwell/utils.py lines 347-361, was also flagged by CodeRabbit:)
🧩 Analysis chain

Verification scripts compared blk_reduce_fp16 against blk_reduce_bf16 / blk_reduce_fp32 in flashinfer/fused_moe/cute_dsl/blackwell/utils.py and consulted the PTX ISA manual (section "Data Movement and Conversion Instructions: cp.reduce.async.bulk"). The relevant shared::cta → global forms are:

```
// shared::cta -> global (bulk_group completion; optional L2 cache hint + cache-policy)
cp.reduce.async.bulk.dst.src.completion_mechanism{.level::cache_hint}.redOp.type
                     [dstMem], [srcMem], size{, cache-policy}
.dst  = { .global }
.src  = { .shared::cta }
.completion_mechanism = { .bulk_group }
.type = { .f16, .bf16, .b32, .u32, .s32, .b64, .u64, .s64, .f32, .f64 }

// f16/bf16 add variant requiring .noftz
cp.reduce.async.bulk.dst.src.completion_mechanism{.level::cache_hint}.add.noftz.type
                     [dstMem], [srcMem], size{, cache-policy}
.type = { .f16, .bf16 }
```

The manual also notes that size must be a multiple of 16, both addresses must be 16-byte aligned, and the instruction was introduced in PTX ISA 8.0 and requires sm_90+.
Add missing .add. reduction operation in blk_reduce_fp16 PTX instruction.
The inline assembly string "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16" is missing the .add. operation specifier. According to the PTX ISA specification, the f16/bf16 add variant requires the syntax .add.noftz.type. Both blk_reduce_bf16 and blk_reduce_fp32 in the same file correctly include the .add. operation, but blk_reduce_fp16 omits it. This will cause a PTX assembly error at kernel compilation time.
Fix

```diff
-        "cp.reduce.async.bulk.global.shared::cta.bulk_group.noftz.f16 [$0], [$1], $2;",
+        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.f16 [$0], [$1], $2;",
```
[CANCELING] Pipeline #43643225: canceled

/bot run

[FAILED] Pipeline #43648436: 14/20 passed
tests seem clean. ok to merge

@nv-yunzheq we can rebase the PR onto the latest main to kick off public CI

/bot run

@flashinfer-bot run

@yongwww What is the command to kick off public CI?

[CANCELING] Pipeline #43802801: canceled

@nv-yunzheq we can either add a label
📌 Description
This PR is a follow-up to PR #2398. It integrates TRTLLM PR 10987, using TMA.RED to improve effective memory bandwidth.

Perf data (tested on GB200):
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed, and all tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
Performance