[CuTe DSL] Add modular FMHA prefill attention kernel#2805
pgera wants to merge 14 commits into flashinfer-ai:main
Conversation
Add CuTe DSL-based attention implementation:

- flashinfer/cute_dsl/attention/ - Modular attention package with composable roles (loader, softmax, MMA, epilogue), fusion points (logits transform, mask, output transform), and schedulers
- flashinfer/cute_dsl/prefill.py - Batch prefill wrapper
- flashinfer/cute_dsl/mla.py - MLA decode wrapper
- flashinfer/cute_dsl/patch/pipeline.py - Pipeline patching utilities

Tests and benchmarks (named to avoid conflicts with existing cutlass tests):

- tests/test_blackwell_fmha_cutedsl.py - FMHA tests (prefill)
- tests/test_blackwell_fmha_attention.py - Modular attention package tests
- tests/test_blackwell_mla_attention.py - MLA attention tests
- tests/test_deepseek_mla_cutedsl.py - DeepSeek MLA tests
- benchmarks/bench_blackwell_attention_cutedsl.py - Attention benchmarks
- docs/cutedsl_fmha_architecture.md - Architecture documentation

Made-with: Cursor
- Delete flashinfer/cute_dsl/prefill.py and mla.py (replaced by the modular flashinfer/cute_dsl/attention/ package)
- Delete tests/test_blackwell_fmha_cutedsl.py and tests/test_deepseek_mla_cutedsl.py (replaced by test_blackwell_fmha_attention.py and test_blackwell_mla_attention.py)
- Revert benchmarks/bench_deepseek_mla.py to upstream version
- Split benchmarks into prefill and decode: bench_blackwell_attention_cutedsl.py (FMHA prefill), bench_blackwell_mla_cutedsl.py (MLA decode)

Made-with: Cursor
…rnels

Two kernel correctness bugs fixed:

1. PV1(end) accumulate flag: the final PV GEMM for stage 1 used a hardcoded accumulate=True, causing stale TMEM data corruption when the KV loop didn't execute (kv_len <= tile_size with multi-Q-tile batches). Fix: use pv_whether_acc instead of True.
2. Causal mask trip count: get_masked_trip_count used ceil_div(M, N), which doesn't account for a non-zero causal_offset shifting the diagonal across extra KV tiles. When kv_len != qo_len, some tiles needing masking were processed as unmasked, leaking unmasked scores into softmax. Fix: compute the masked tile count from the actual diagonal boundary positions.

Both fixes required threading seqlen_q through the mask functions and passing causal_offset to apply_mask.

Test suite pruned to ~112 curated cases covering tile boundaries, GQA, varlen, causal, output/logits transforms, and attention sink.

AI-assisted (Claude)
Made-with: Cursor
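The trip-count fix above can be sketched in plain Python (a standalone illustration of the shifted-diagonal arithmetic, not the kernel's actual CuTe code; the function and parameter names here are illustrative):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def causal_trip_count(q_tile_idx: int, tile_m: int, tile_n: int,
                      qo_len: int, kv_len: int) -> int:
    """KV tiles a Q tile must visit under a causal mask.

    When kv_len != qo_len the diagonal is shifted by
    causal_offset = kv_len - qo_len; the bug was using ceil_div on the
    Q extent alone, ignoring that shift.
    """
    causal_offset = kv_len - qo_len
    max_blocks_k = ceil_div(kv_len, tile_n)
    # the last Q row of this tile may attend up to column (row + causal_offset)
    max_blocks_q = ceil_div((q_tile_idx + 1) * tile_m + causal_offset, tile_n)
    return min(max_blocks_k, max_blocks_q)
```

With qo_len=128, kv_len=256, and 128x128 tiles, the naive count gives 1 KV tile for the first Q tile while the shifted diagonal actually requires 2, which is exactly the "unmasked scores leaking into softmax" scenario described above.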
…ate domain conversion

The plan() method created a dummy sink tensor with a hardcoded float16 dtype for JIT compilation, regardless of input dtype. When bfloat16 inputs were used at runtime, the compiled kernel misinterpreted bf16 bits as fp16, garbling sink values (causal row-0 error: 1.75 -> 0.004).

Also fix the sink_M_D_update test helper to properly convert the sink value from scaled-logit space to raw-logit space by dividing by scale, and tighten the sink test tolerance from atol=2.0 to atol=0.01.

AI-assisted (Claude)
Made-with: Cursor
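The scaled-to-raw domain conversion in the sink fix can be checked numerically with a small pure-Python reference (a sketch of the math described above, not the kernel or test-helper code; names are illustrative):

```python
import math

LOG2_E = math.log2(math.e)

def softmax_denom_with_sink(logits_raw, sink_scaled, scale):
    """Softmax denominator with an attention sink, online-softmax style.

    The running max m lives in the RAW (unscaled) logit domain, and each
    probability is evaluated as exp2((x - m) * scale * LOG2_E). A sink
    value given in scaled-logit space must therefore be divided by scale
    before it can be compared against m or folded into the sum.
    """
    sink_raw = sink_scaled / scale
    m = max(max(logits_raw), sink_raw)
    d = sum(2.0 ** ((x - m) * scale * LOG2_E) for x in logits_raw)
    d += 2.0 ** ((sink_raw - m) * scale * LOG2_E)
    return d, m
```

Multiplying the result back by exp(m * scale) recovers the plain-domain sum(exp(x * scale)) + exp(sink), which is why dividing the sink by scale is the correct conversion.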
…ve tests (AI-assisted)

Kernel fixes:
- Sliding window apply_mask: add missing left-bound check (|kv - q| > window) and seqlen_k bounds check
- Sliding window get_trip_count/get_kv_start_block_idx: compute the correct symmetric window tile range instead of a right-only approximation
- Softmax run(): add kv_start_offset to the coordinate identity tensor so mask coordinates match the actual KV positions loaded by the TMA loader

Test fixes:
- sink_M_D_update: add * scale to the exp2 rescale terms for correctness (m is in the RAW domain; exp2 needs domain conversion via * scale)
- Sink test: use SM_SCALE=1/sqrt(head_dim) instead of 1.0, which made the sink contribution negligible (~0) and the test vacuous

New test coverage:
- float16 dtype (3 shapes x 2 causal)
- Sliding window mask (4 window/shape combos)
- head_dim=64 (3 shapes x 2 causal)
- Variable-length + sigmoid logits transform (2 indptr patterns)
- Variable-length + attention sink (2 indptr patterns)
- Attention sink with MHA / num_kv_heads=32 (2 shapes x 2 causal)

All 118 tests pass, 18 skipped (qo > kv + causal), ~10 min runtime.

Made-with: Cursor
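The symmetric window tile range from the sliding-window fix can be sketched as standalone Python (mirroring the clamp-and-divide logic, not the CuTe code itself; names are illustrative):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def sliding_window_trip_count(q_tile_idx: int, tile_m: int, tile_n: int,
                              seqlen_k: int, window_left: int) -> int:
    """KV tiles covered by a symmetric window |kv - q| <= window_left.

    Both window bounds are clamped to [0, seqlen_k); a right-only
    approximation misses the left clamp and overcounts leading tiles.
    """
    first_q = q_tile_idx * tile_m
    last_q = (q_tile_idx + 1) * tile_m - 1
    min_kv = max(0, first_q - window_left)
    max_kv = min(seqlen_k - 1, last_q + window_left)
    start_block = min_kv // tile_n
    end_block = ceil_div(max_kv + 1, tile_n)
    return end_block - start_block
```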
Strip out the MLA decode kernel, config, warp schedule, roles, scheduler, wrapper, benchmark, test, and architecture doc to keep this PR focused on FMHA prefill only. Clean up MLA references in shared modules.

AI-assisted
Made-with: Cursor
📝 Walkthrough

This PR introduces a comprehensive CuTe DSL-based fused multi-head attention (FMHA) implementation for Blackwell GPUs. It adds modular attention kernel infrastructure with configurable variants, masking, pipeline topology, and a PyTorch wrapper for batch prefill operations, complemented by benchmarking and extensive testing.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: 2 passed ✅, 1 failed ❌ (1 warning)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a refactored and improved FMHA prefill kernel using a modular design. The new implementation addresses several bugs present in the previous version and offers a more flexible and testable architecture. By composing roles and using declarative pipeline topologies, the kernel supports a wide range of configurations and customizations, making it suitable for various attention mechanisms.
Code Review
This pull request introduces a significant and well-structured modular rewrite of the FMHA prefill kernel using CuTe DSL. The new architecture, based on composable roles and declarative pipeline topologies, is a major improvement for maintainability and extensibility. The comprehensive test suite covers a wide range of configurations, ensuring the robustness of the new implementation. My review includes a few suggestions for code cleanup, improving maintainability by adding documentation, and addressing potential issues like unused code and a missing JIT decorator.
```python
def get_trip_count(
    mask_type: MaskType,
    window_left: int,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_k: Int32,
    seqlen_q: Int32 = 0,
) -> Int32:
    """Number of KV tile blocks to process for this Q tile."""
    result = 0
    if mask_type == MaskType.NO_MASK or mask_type == MaskType.RESIDUAL_MASK:
        result = cute.ceil_div(seqlen_k, tile_shape[1])
    elif mask_type == MaskType.CAUSAL_MASK:
        max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
        causal_offset = seqlen_k - seqlen_q
        max_blocks_q = cute.ceil_div(
            (blk_coord[0] + 1) * tile_shape[0] + causal_offset, tile_shape[1]
        )
        result = cutlass.min(max_blocks_k, max_blocks_q)
    elif mask_type == MaskType.SLIDING_WINDOW_MASK:
        first_q = blk_coord[0] * tile_shape[0]
        last_q = (blk_coord[0] + 1) * tile_shape[0] - 1
        min_kv = cutlass.max(0, first_q - window_left)
        max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
        start_block = min_kv // tile_shape[1]
        end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
        result = end_block - start_block
    return result
```
The function get_trip_count is called from other JIT-compiled functions (e.g., in loader_tma.py, mma.py, correction.py, and indirectly from softmax.py) but it is not decorated with @cute.jit. While this might be handled by the compiler through inlining, it's better to be explicit to ensure correctness and improve maintainability. Please add the @cute.jit decorator to this function.
```python
@cute.jit
def get_trip_count(
    mask_type: MaskType,
    window_left: int,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_k: Int32,
    seqlen_q: Int32 = 0,
) -> Int32:
    """Number of KV tile blocks to process for this Q tile."""
    result = 0
    if mask_type == MaskType.NO_MASK or mask_type == MaskType.RESIDUAL_MASK:
        result = cute.ceil_div(seqlen_k, tile_shape[1])
    elif mask_type == MaskType.CAUSAL_MASK:
        max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
        causal_offset = seqlen_k - seqlen_q
        max_blocks_q = cute.ceil_div(
            (blk_coord[0] + 1) * tile_shape[0] + causal_offset, tile_shape[1]
        )
        result = cutlass.min(max_blocks_k, max_blocks_q)
    elif mask_type == MaskType.SLIDING_WINDOW_MASK:
        first_q = blk_coord[0] * tile_shape[0]
        last_q = (blk_coord[0] + 1) * tile_shape[0] - 1
        min_kv = cutlass.max(0, first_q - window_left)
        max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
        start_block = min_kv // tile_shape[1]
        end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
        result = end_block - start_block
    return result
```

```python
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Shared TMEM utilities for compute roles.

Provides tmem_load_partition() — partitions TMEM output accumulator for
load/store by the rescale and epilogue roles.
"""

from types import SimpleNamespace

import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.tcgen05 as tcgen05


@cute.jit
def tmem_load_partition(
    tmem_ptr: cutlass.Int32,
    tmem_o_offset: int,
    acc_dtype: cutlass.Constexpr,
    mma_pv_tiler: cutlass.Constexpr,
    cluster_shape_mnk: cutlass.Constexpr,
    warps_in_n: int,
    num_compute_warps: int,
    threads_per_warp: int,
    common_params: SimpleNamespace,
    tiled_mma_pv: cute.TiledMma,
    iter_n: int,
) -> tuple[
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
]:
    tOtO_shape = tiled_mma_pv.partition_shape_C(
        cute.select(mma_pv_tiler, mode=[0, 1])
    )
    tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape)
    tOtO_layout = cute.append(
        tOtO.layout,
        cute.make_layout(
            common_params.L // mma_pv_tiler[1],
            stride=mma_pv_tiler[1] // warps_in_n,
        ),
    )
    tOtO = cute.make_tensor(tmem_ptr + tmem_o_offset, tOtO_layout)
    tOtO = tOtO[None, None, None, iter_n]

    tAcc = tOtO[(None, None), 0, 0]

    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), acc_dtype
    )
    tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc)
    tmem_load_thr_copy = tmem_load_tiled_copy.get_slice(
        common_params.tidx % (num_compute_warps * threads_per_warp)
    )

    cta_pv_tiler = (
        mma_pv_tiler[0] // cluster_shape_mnk[0],
        mma_pv_tiler[1],
        mma_pv_tiler[2],
    )
    cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1])

    gO = None
    if cutlass.const_expr(common_params.mAccO is not None):
        gO = cute.local_tile(
            common_params.mAccO[None, common_params.blk_coord[3], None, None],
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
        cO = cute.local_tile(
            cute.make_identity_tensor(
                common_params.mAccO[
                    None, common_params.blk_coord[3], None, None
                ].shape
            ),
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
    else:
        gO = cute.local_tile(
            common_params.mO,
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
        cO = cute.local_tile(
            cute.make_identity_tensor(common_params.mO.shape),
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )

    tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc)
    tTR_gO = tmem_load_thr_copy.partition_D(gO)
    tTR_cO = tmem_load_thr_copy.partition_D(cO)
    tTR_rAcc = cute.make_fragment_like(tTR_gO, acc_dtype)
    return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc
```
```python
if s_k.shape[0] > 1:
    for i in range(len(s_k)):
        if s_k[i] % self._mma_tiler_mn[1] != 0:
            self._mask_type = MaskType.RESIDUAL_MASK
else:
    if s_k % self._mma_tiler_mn[1] != 0:
        self._mask_type = MaskType.RESIDUAL_MASK
```
The logic to determine if RESIDUAL_MASK is needed can be simplified. The current implementation iterates over the s_k tensor and has a branch for s_k.shape[0] > 1 which is always taken since s_k is derived from kv_indptr and will be a 1D tensor. You can use torch.any for a more concise and efficient check.
```python
if torch.any(s_k % self._mma_tiler_mn[1] != 0):
    self._mask_type = MaskType.RESIDUAL_MASK
```

New file flashinfer/cute_dsl/patch/pipeline.py (+419 lines), beginning:

```python
from typing import Optional, Type, Tuple
```
This file provides a custom implementation of pipeline participants, which appears to be a patch on top of cutlass.pipeline. However, it's missing a file-level docstring explaining why this custom implementation is necessary and what it changes compared to the original. Adding a docstring would greatly improve maintainability and make it easier for other developers to understand the purpose of this module.
```python
def sink_M_D_update(params, kv_tile_idx, qo_head_idx, m, d, scale):
    # m is in the RAW (unscaled) domain; convert sink from scaled-logit to RAW
    log2_e = math.log2(math.exp(1.0))
    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
```
The condition qo_head_idx < NUM_QO_HEADS is redundant because qo_head_idx is a grid coordinate over the heads dimension, which is sized to NUM_QO_HEADS. Therefore, qo_head_idx will always be less than NUM_QO_HEADS. You can simplify the expression.
```diff
-    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
+    sink_raw = params.sink[qo_head_idx] * log2_e / scale if kv_tile_idx == 0 else -math.inf
```
```python
@cute.jit
def sink_M_D_update(params, kv_tile_idx, qo_head_idx, m, d, scale):
    log2_e = math.log2(math.exp(1.0))
    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
```
The condition qo_head_idx < NUM_QO_HEADS is redundant because qo_head_idx is a grid coordinate over the heads dimension, which is sized to NUM_QO_HEADS. Therefore, qo_head_idx will always be less than NUM_QO_HEADS. You can simplify the expression.
```diff
-    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
+    sink_raw = params.sink[qo_head_idx] * log2_e / scale if kv_tile_idx == 0 else -math.inf
```
The PipelineProducer/PipelineConsumer wrappers are now available in cutlass.pipeline (nvidia-cutlass-dsl >= 4.3). Use them directly instead of maintaining a local copy. Pipeline creation uses defer_sync=True since the kernel handles barrier init/sync separately.

Verified: no perf regression (< 1% noise), all 118 tests pass.

AI-assisted
Made-with: Cursor
…ed tmem_utils.py

- Add missing @cute.jit decorator to get_trip_count for consistency with all other functions in mask.py
- Remove tmem_utils.py, which was MLA-specific dead code after MLA removal

AI-assisted
Made-with: Cursor
…AI-assisted)

- Use torch.any() for concise RESIDUAL_MASK determination in batch_prefill.py
- Remove the always-true qo_head_idx < NUM_QO_HEADS condition in sink_M_D_update tests

Made-with: Cursor
…d hook (AI-assisted)

Replace the dual sink_iter + variant_data_iter kernel parameters with a single params_iter path. Variants now expose runtime tensor data via an extra_params property; the kernel binds it to self.params before calling any JIT method.

Key changes:
- AttentionVariant: remove use_attention_sink/variant_data_tensor; add extra_params property and score_mod hook
- AttentionWithSink: take the sink tensor in the constructor instead of run()
- AttentionFusion: replace use_attention_sink + has_variant_data with has_params + params_shape + params_strides
- prefill.py/softmax.py: merge sink + variant_data into a single params
- batch_prefill.py: remove sink from run(); add a contiguity assertion
- Fix the CuTe column-major vs PyTorch row-major layout mismatch for N-D params by deriving explicit strides from the PyTorch tensor
- Add ALiBiAttention (1-D params), RPEAttention (2-D params), and SoftCappingAttention (compile-time only) variant classes
- Add ALiBi and RPE tests with tight tolerances

Made-with: Cursor
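The variant pattern described above can be sketched in plain Python. The class and hook names (AttentionVariant, extra_params, score_mod, ALiBiAttention) come from the commit message, but the signatures below are assumptions for illustration, not the kernel's real JIT API:

```python
class AttentionVariant:
    """Base variant: plain softmax attention with no runtime params."""

    @property
    def extra_params(self):
        # No runtime tensor data by default; the kernel would bind this
        # to self.params before calling any JIT method (per the PR).
        return None

    def score_mod(self, score, q_idx, kv_idx, head_idx, params):
        return score  # identity


class ALiBiAttention(AttentionVariant):
    """Variant with 1-D runtime params: one ALiBi slope per head."""

    def __init__(self, slopes):
        self._slopes = slopes

    @property
    def extra_params(self):
        return self._slopes

    def score_mod(self, score, q_idx, kv_idx, head_idx, params):
        # ALiBi: subtract a per-head linear distance penalty
        return score - params[head_idx] * abs(q_idx - kv_idx)
```

The point of the single-params design is that every variant, whether it carries a sink, slopes, or an RPE table, flows through the same extra_params binding instead of dedicated kernel arguments.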
Made-with: Cursor
Made-with: Cursor
Made-with: Cursor
…assisted)

- Add can_implement() to AttentionConfig for early validation of MMA tile size, head_dim divisibility, and GQA repeat count
- Add an SMEM capacity check in the prefill kernel to catch head_dim overruns with a clear error instead of a cryptic CUDA runtime error
- Add _validate_run_inputs() to BatchPrefillCuteDSLWrapper for runtime dtype/device/shape consistency checks
- Change MainloopSpec.resolve() to return a new object instead of mutating in place, avoiding subtle state bugs
- Clarify docstrings for decode-reserved fields and pipeline ordering

Made-with: Cursor
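The can_implement() early-validation pattern can be sketched as follows. The field names and the specific thresholds here are illustrative assumptions, not the PR's actual checks:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionConfig:
    # Hypothetical fields for illustration; the real config differs.
    mma_tiler: tuple
    head_dim: int
    num_qo_heads: int
    num_kv_heads: int

    def can_implement(self):
        """Return (ok, reason) before any JIT compilation is attempted."""
        if self.mma_tiler[0] not in (64, 128, 256):
            return False, f"unsupported MMA tile M={self.mma_tiler[0]}"
        if self.head_dim % 8 != 0:
            return False, f"head_dim={self.head_dim} must be divisible by 8"
        if self.num_qo_heads % self.num_kv_heads != 0:
            return False, "num_qo_heads must be a multiple of num_kv_heads (GQA)"
        return True, ""
```

A wrapper can call this at plan() time so an unsupported configuration fails with a readable message instead of a cryptic CUDA runtime error, which is the motivation stated in the commit.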
Actionable comments posted: 9
🧹 Nitpick comments (13)
flashinfer/cute_dsl/attention/tmem_layout.py (1)
35-49: Consider extracting `SM100_TMEM_CAPACITY_COLUMNS` as a module-level constant.

The SM100 TMEM capacity is a hardware characteristic that may be referenced elsewhere. Extracting it improves discoverability and avoids magic numbers.
Proposed refactor
```diff
+SM100_TMEM_CAPACITY_COLUMNS = 512
+
 @dataclass(frozen=True)
 class TmemLayout:
     ...
     @staticmethod
     def from_config(config: AttentionConfig) -> TmemLayout:
         tile_m = config.mma_tiler[0]
-        SM100_TMEM_CAPACITY_COLUMNS = 512
         return TmemLayout(
             ...
             alloc_cols=SM100_TMEM_CAPACITY_COLUMNS,
         )
```
flashinfer/cute_dsl/attention/scheduler/persistent.py (2)
38-45: Add `strict=True` to `zip()` for safer MLIR value reconstruction.

In `__new_from_mlir_values__`, the `zip()` call iterates over `[self.is_persistent, self.problem_shape_mbh]` and `self._values_pos`. If these lists have mismatched lengths (e.g., due to a maintenance error), `zip()` will silently truncate, potentially causing subtle bugs during MLIR reconstruction. Also, the `ip` parameter is not forwarded to the new `FmhaStaticTileSchedulerParams` instance on line 45.

Proposed fix
```diff
 def __new_from_mlir_values__(self, values):
     obj_list = []
     for obj, n_items in zip(
-        [self.is_persistent, self.problem_shape_mbh], self._values_pos
+        [self.is_persistent, self.problem_shape_mbh], self._values_pos, strict=True
     ):
         obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
         values = values[n_items:]
-    return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
+    return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc, ip=self._ip)
```
148-158: Hardcoded MLIR value count is fragile.

The assertion `assert len(values) == 10` couples the implementation to a specific MLIR representation. If any constituent object's MLIR value count changes, this will fail without a clear message. Consider deriving the expected count dynamically or providing a descriptive error message.

Proposed improvement
```diff
 def __new_from_mlir_values__(self, values):
-    assert len(values) == 10
+    expected = 3 + 1 + 3 + 3  # params(3) + work_idx(1) + blk_coord(3) + grid_shape(3)
+    assert len(values) == expected, f"Expected {expected} MLIR values, got {len(values)}"
     new_params = cutlass.new_from_mlir_values(self._params, values[0:3])
```
flashinfer/cute_dsl/attention/collective_builder.py (1)
163-186: Consider using a typed dataclass instead of `SimpleNamespace` for better IDE support.

The returned `SimpleNamespace` contains 20+ fields. A typed dataclass or `NamedTuple` would provide autocompletion and type checking for consumers.
benchmarks/bench_blackwell_attention_cutedsl.py (1)
7-8: Use the public `flashinfer.testing` benchmark helper.

This benchmark already relies on the standard timing helper, but it pulls it from `flashinfer.testing.utils`, which couples the script to a private module path.

♻️ Suggested change
```diff
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing import bench_gpu_time
```

Based on learnings
Use flashinfer.testing.bench_gpu_time() for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events.
tests/test_blackwell_fmha_attention.py (1)
1-13: Please move this suite under a feature-specific tests subdirectory.

This is kernel-specific CuTe DSL attention coverage, but the new module sits at `tests/` root. Putting it under a matching subdirectory keeps the test surface organized with the rest of the kernel-category suites.

As per coding guidelines (`tests/**/*.py`): prefix test functions with `test_` and structure tests by feature in `tests/` subdirectories matching kernel categories.
flashinfer/cute_dsl/attention/wrappers/batch_prefill.py (3)
393-396: Add `strict=True` to `zip()` for early shape-mismatch detection.

Using `strict=True` catches mismatched lengths between `padding` and `shape_` early, improving debuggability.

Suggested fix
```diff
-    slices = tuple(slice(s, e) for s, e in zip(padding, shape_))
+    slices = tuple(slice(s, e) for s, e in zip(padding, shape_, strict=True))
```
129-157: Prefix unused unpacked variables with underscore.

The variables `q_ref`, `q_torch`, `k_ref`, `k_torch`, `v_ref`, `v_torch`, and `o_torch` from `create_and_pad_tensor()` are intentionally unused (they're dummy tensors for CuTe JIT tracing). Prefix them with `_` to indicate intent and silence linter warnings.

Suggested fix
```diff
-    q_ref, q_cute, q_torch = create_and_pad_tensor(
+    _q_ref, q_cute, _q_torch = create_and_pad_tensor(
         qo_shape, ...
     )
-    k_ref, k_cute, k_torch = create_and_pad_tensor(
+    _k_ref, k_cute, _k_torch = create_and_pad_tensor(
         kv_shape, ...
     )
-    v_ref, v_cute, v_torch = create_and_pad_tensor(
+    _v_ref, v_cute, _v_torch = create_and_pad_tensor(
         kv_shape, ...
     )
-    _, o_cute, o_torch = create_and_pad_tensor(
+    _, o_cute, _o_torch = create_and_pad_tensor(
         qo_shape, ...
     )
```
318-319: Minor: `device=q.device` is redundant with `torch.empty_like`.

`torch.empty_like(q, ...)` already inherits `q`'s device by default.

Suggested fix
```diff
-    out = torch.empty_like(q, device=q.device)
+    out = torch.empty_like(q)
```
flashinfer/cute_dsl/attention/roles/softmax.py (1)
336-344: Redundant `thread_idx` computation.

`thread_idx` is computed identically at lines 337-344 and again at lines 366-373. The second computation overwrites the first with the same value.

Remove duplicate computation
```diff
     thread_idx = tidx % (
         self.threads_per_warp
         * (
             len(self.softmax0_warp_ids)
             if stage == 0
             else len(self.softmax1_warp_ids)
         )
     )
     ...
     tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi)
-    thread_idx = tidx % (
-        self.threads_per_warp
-        * (
-            len(self.softmax0_warp_ids)
-            if stage == 0
-            else len(self.softmax1_warp_ids)
-        )
-    )
     thr_tmem_load = tiled_tmem_load.get_slice(thread_idx)
```
flashinfer/cute_dsl/attention/prefill.py (3)
155-156: Prefix unused `s_k` with underscore.

`s_k` is unpacked but never used. Prefix it with `_` to indicate intent.

```diff
-    b, s_q, s_k, h_q, h_k, d = problem_size
+    b, s_q, _s_k, h_q, h_k, d = problem_size
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/prefill.py` around lines 155 - 156, The tuple unpacking of problem_size currently binds an unused variable s_k; change the unpacking to use _s_k (or simply _ ) instead of s_k to signal it's intentionally unused (e.g., replace "b, s_q, s_k, h_q, h_k, d = problem_size" with an unpacking that prefixes s_k with an underscore) in the prefill logic where variables b, s_q, h_q, h_k, d are used and h_r is computed from h_q and h_k.
45-51: Overly broad warning suppression may hide legitimate issues.

Suppressing all `UserWarning` messages (line 51) could mask important warnings from other parts of the codebase or dependencies. Consider scoping the suppression more narrowly, or applying it only within the specific context where the unrolling warning occurs. Alternative: use a context manager at call sites:

```python
# Remove the global filter at module level
# warnings.filterwarnings("ignore", category=UserWarning)

# Instead, wrap specific calls that generate the warning:
import contextlib

@contextlib.contextmanager
def suppress_loop_unroll_warning():
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="This loop is no longer unrolled and may cause performance regression",
        )
        yield
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/prefill.py` around lines 45 - 51, The module currently suppresses all UserWarning globally by calling warnings.filterwarnings("ignore", category=UserWarning); remove that broad module-level filter and instead scope suppression to only the specific unroll warning by introducing a context manager (e.g., suppress_loop_unroll_warning using warnings.catch_warnings and warnings.filterwarnings with message="This loop is no longer unrolled and may cause performance regression") and use that context manager at the specific call sites in prefill.py where the unrolling warning is raised so other UserWarnings remain visible.
385-386: Prefix unused `tidx` with underscore.

`tidx` from `thread_idx()` is unpacked but unused in the kernel entry. The variable is only used by roles that call `thread_idx()` themselves.

```diff
-tidx, _, _ = cute.arch.thread_idx()
+_tidx, _, _ = cute.arch.thread_idx()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/prefill.py` around lines 385 - 386, The unpacked thread index variable tidx from cute.arch.thread_idx() is unused in the kernel entry; change its name to _tidx to mark it as intentionally unused (i.e., replace "tidx, _, _ = cute.arch.thread_idx()" with "_tidx, _, _ = cute.arch.thread_idx()") so linters/readers know it's unused while keeping the other unpacked values and the existing warp_idx assignment (warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())) intact.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_blackwell_attention_cutedsl.py`:
- Around line 153-161: The script currently unconditionally runs an SM100-only
kernel in the __main__ block (calls to bench_fmha_cutedsl), which will
JIT/launch-fail on non-SM100 GPUs; add a GPU capability check before running the
default sweep: use torch.cuda.is_available() and
torch.cuda.get_device_capability() or
torch.cuda.get_device_properties(device).major/minor (or device name) to detect
whether the current GPU supports SM100, and if not, skip the default
bench_fmha_cutedsl(...) calls and exit or print a clear message; update the
__main__ section so the SM100-only sweep only runs when the capability check
passes.
In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 96-98: The p_tmem_layout_staged is being created with the wrong
dtype (q_dtype) causing a mismatch with pv_tiled_mma which was created for V;
update the call to sm100_utils.make_smem_layout_a in collective_builder so
p_tmem_layout_staged uses v_dtype instead of q_dtype (the call that takes
pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage should pass
v_dtype) to align the P buffer TMEM layout with the V buffer.
In `@flashinfer/cute_dsl/attention/fusion/mask.py`:
- Around line 45-53: Sliding-window masking currently centers the KV window on
raw Q indices (see MaskType.SLIDING_WINDOW_MASK handling using blk_coord,
tile_shape, window_left, seqlen_k) and ignores the Q/K length offset used by the
causal path; compute q_k_offset = seqlen_k - seqlen_q and add it to first_q and
last_q (or otherwise shift Q indices into KV space) before calculating min_kv,
max_kv, start_block, end_block, and any element masks; apply the same fix to the
other sliding-window blocks noted (around the other occurrences at the given
ranges) so all sliding-window computations use shifted Q indices into KV
coordinate space.
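To make the shift concrete, here is a minimal pure-Python sketch of the trip-count arithmetic described above. The function name and flat parameter list are hypothetical (the real code works on `blk_coord`/`tile_shape` tuples with `cutlass`/`cute` primitives); the point is that `q_offset = seqlen_k - seqlen_q` moves Q-tile indices into KV coordinate space before the window bounds are computed:

```python
import math

def sliding_window_trip_count(q_tile_idx, tile_m, tile_n,
                              window_left, seqlen_q, seqlen_k):
    """Count KV tiles intersecting a left-bounded sliding window,
    after shifting Q indices into KV coordinate space (hypothetical
    helper mirroring the fix described above)."""
    q_offset = seqlen_k - seqlen_q  # same offset the causal path uses
    first_q = q_tile_idx * tile_m + q_offset
    last_q = (q_tile_idx + 1) * tile_m - 1 + q_offset
    min_kv = max(0, first_q - window_left)
    max_kv = min(seqlen_k - 1, last_q + window_left)
    start_block = min_kv // tile_n
    end_block = math.ceil((max_kv + 1) / tile_n)
    return end_block - start_block
```

With `seqlen_q=128, seqlen_k=1024` (a suffix-prefill shape), the unshifted version would center the window on KV tiles 0..1; the shifted version correctly lands on the tiles near the end of the KV sequence.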
In `@flashinfer/cute_dsl/attention/fusion/variant.py`:
- Around line 551-554: SoftCappingAttention.score_mod calls non-existent
cute.arch.tanh; replace it with a local tanh approximation implemented using
available primitives (e.g., cute.arch.exp2 and cute.arch.rcp_approx) or a cheap
rational polynomial and call that from score_mod. Add a helper function (e.g.,
_tanh_approx(x)) in the same class or module and use it in
SoftCappingAttention.score_mod (referencing self.cap and self.rcp_cap as
before), implementing tanh(x) via exp2 by computing exp(-2*abs(x)) with
exp2(-2*abs(x)/ln2) plus sign handling or by a stable rational approximation
(polynomial numerator/denominator) and ensure the helper uses
cute.jit-compatible operations only.
- Around line 367-378: Update the class and relevant parameter docstrings to
state that sink values are expected in the logit domain (raw Q·K dot-product
units, unnormalized), not pre-scaled to log2; specifically mention this near the
documentation for the sink parameter(s) used by update_statistics and the
self.params/sink_raw conversion (which divides by scale/log2_e), and add a
cross-reference to sink_softmax in sink_attention_reference.py so callers know
sinks are concatenated to logits before any log2 scaling.
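To illustrate the logit-domain convention the docstring should spell out, here is a hedged reference sketch of one common sink-softmax formulation (the sink logit joins the denominator in raw, unscaled units and its probability mass is then dropped; the actual semantics are defined by `sink_softmax` in sink_attention_reference.py, which should be treated as authoritative):

```python
import math

def sink_softmax(logits, sink):
    """Softmax over `logits` with an extra sink logit folded into the
    denominator, both in the same raw (pre-log2-scaling) domain.
    The sink's own probability mass is discarded, so the returned
    probabilities sum to less than 1. Sketch only."""
    m = max(max(logits), sink)              # shared max for stability
    exps = [math.exp(v - m) for v in logits]
    denom = sum(exps) + math.exp(sink - m)  # sink joins the denominator
    return [e / denom for e in exps]
```

With two zero logits and a zero sink, each position gets 1/3 and the outputs sum to 2/3, which is the observable signature of a sink term.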
In `@flashinfer/cute_dsl/attention/pipeline_topology.py`:
- Around line 68-79: The Pipeline.dataclass field cluster_scale is ignored by
create_pipelines(), causing incorrect participant and barrier arrive counts;
either (preferred) honor it by multiplying the all-thread side's participant
counts when constructing producer/consumer groups and computing barrier arrive
counts for PipelineType values UMMA_ASYNC and ASYNC_UMMA (but leave TMA_UMMA
unchanged), i.e., when building groups from producer_warp_ids/consumer_warp_ids
in create_pipelines() multiply the thread counts by pipeline.cluster_scale and
use that scaled value when setting arrive counts for barriers/tx_count_key, or
fail fast by adding a check in create_pipelines() that raises a clear exception
if pipeline.cluster_scale != 1 so callers must handle scaling explicitly.
In `@flashinfer/cute_dsl/attention/roles/epilogue.py`:
- Around line 41-66: partition_output is incorrectly decorated with `@cute.jit`
while returning tensor objects (tOsO, tOgO) which violates the CuTe JIT
limitation; either remove the `@cute.jit` decorator from partition_output so it
runs as a normal Python method, or refactor it to avoid returning tensors by (a)
accepting preallocated output containers/handles and writing into them, or (b)
moving the cute.nvgpu.cpasync.tma_partition call out of the `@cute.jit` function
into a non-jit wrapper (e.g., create partition_output_nonjit that calls
cute.nvgpu.cpasync.tma_partition and returns tensors or change partition_output
to populate passed-in tensor references); update references to partition_output
accordingly so no `@cute.jit` function returns tensors (symbols: partition_output,
tOsO, tOgO, tma_partition, tma_atom_o).
In `@flashinfer/cute_dsl/attention/warp_schedule.py`:
- Around line 17-71: Add a fail-fast validation in WarpSchedule (implement in a
__post_init__ method) that verifies: 1) all_warp_ids (built from
softmax0_warp_ids, softmax1_warp_ids, correction_warp_ids, mma_warp_id,
load_warp_id, epilogue_warp_id, empty_warp_id) contain unique values and form a
contiguous range starting at 0 up to len(all_warp_ids)-1, and 2) the total
number of softmax warps (len(softmax0_warp_ids)+len(softmax1_warp_ids)) is
divisible by num_warps_per_warpgroup; on violation raise ValueError with a clear
message referencing the failing condition so consumers of num_warps,
threads_per_cta, and softmax_warpgroup_count cannot silently compute incorrect
sizes.
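A minimal sketch of the requested `__post_init__` validation, with an assumed (hypothetical) field layout; the real WarpSchedule has its own defaults and field types, so only the two checks themselves are the point:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class WarpSchedule:
    # Field names mirror the review comment; defaults are illustrative.
    softmax0_warp_ids: tuple = (0, 1, 2, 3)
    softmax1_warp_ids: tuple = (4, 5, 6, 7)
    correction_warp_ids: tuple = (8, 9, 10, 11)
    mma_warp_id: int = 12
    load_warp_id: int = 13
    epilogue_warp_id: int = 14
    empty_warp_id: int = 15
    num_warps_per_warpgroup: int = 4

    def __post_init__(self):
        all_ids = (
            list(self.softmax0_warp_ids)
            + list(self.softmax1_warp_ids)
            + list(self.correction_warp_ids)
            + [self.mma_warp_id, self.load_warp_id,
               self.epilogue_warp_id, self.empty_warp_id]
        )
        # Check 1: unique, contiguous ids starting at 0.
        if sorted(all_ids) != list(range(len(all_ids))):
            raise ValueError(
                f"warp ids must be unique and contiguous from 0, got {sorted(all_ids)}"
            )
        # Check 2: softmax warps fill whole warpgroups.
        n_softmax = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
        if n_softmax % self.num_warps_per_warpgroup != 0:
            raise ValueError(
                f"softmax warp count {n_softmax} is not divisible by "
                f"warpgroup size {self.num_warps_per_warpgroup}"
            )
```

Failing fast here means downstream consumers of `num_warps`, `threads_per_cta`, and `softmax_warpgroup_count` can never see a silently inconsistent schedule.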
In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py`:
- Around line 159-169: The NameError risk comes from params_cute being defined
only inside the if self._has_params block yet referenced later; fix by defining
params_cute = None before the if and only assigning it inside the block (where
you call from_dlpack) so later code can safely use the conditional expression
(params_cute.iterator if self._has_params else None); update references
involving self._has_params, _params_torch, and from_dlpack accordingly to rely
on the initialized params_cute variable.
---
Nitpick comments:
In `@benchmarks/bench_blackwell_attention_cutedsl.py`:
- Around line 7-8: The benchmark imports bench_gpu_time from a private path
(flashinfer.testing.utils); update the import to use the public helper by
replacing references to flashinfer.testing.utils with the public module
flashinfer.testing and import bench_gpu_time from flashinfer.testing (i.e., use
flashinfer.testing.bench_gpu_time) so the benchmark relies on the supported
public API rather than a private module.
In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 163-186: Replace the anonymous SimpleNamespace return with a typed
dataclass (e.g., define `@dataclass` class AttentionCollective or
AttentionCollectiveConfig) that declares typed fields for each symbol currently
passed (qk_tiled_mma, pv_tiled_mma, tma_atom_q, tma_tensor_q, tma_atom_k,
tma_tensor_k, tma_atom_v, tma_tensor_v, tma_atom_o, tma_tensor_o,
q_smem_layout_staged, k_smem_layout_staged, p_tmem_layout_staged,
v_smem_layout_staged, o_smem_layout_staged, SharedStorage, tma_copy_q_bytes,
tma_copy_kv_bytes, cluster_shape_mnk, cluster_layout_vmnk, epi_tile, o_layout),
add appropriate type hints (use typing.Any or more specific types if known),
import dataclasses and typing, instantiate and return that dataclass instead of
SimpleNamespace, and update any consumers to accept the new dataclass type for
improved IDE autocompletion and type checking.
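The shape of the suggested change, abridged to a few of the fields listed above (the full dataclass would declare all of them; `Any` stands in where more precise CuTe types are known):

```python
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class AttentionCollective:
    """Typed replacement for the anonymous SimpleNamespace return value
    (abridged sketch; the real class declares every field the builder
    currently stuffs into the namespace)."""
    qk_tiled_mma: Any
    pv_tiled_mma: Any
    tma_copy_q_bytes: int
    tma_copy_kv_bytes: int

# Construction site in the builder changes from SimpleNamespace(...) to:
collective = AttentionCollective(
    qk_tiled_mma=None,          # placeholder; real MMA object in the builder
    pv_tiled_mma=None,
    tma_copy_q_bytes=16384,
    tma_copy_kv_bytes=32768,
)
```

Unlike SimpleNamespace, a frozen dataclass gives attribute autocompletion, type checking on each field, and an immediate `TypeError` if a field is forgotten at construction.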
In `@flashinfer/cute_dsl/attention/prefill.py`:
- Around line 155-156: The tuple unpacking of problem_size currently binds an
unused variable s_k; change the unpacking to use _s_k (or simply _ ) instead of
s_k to signal it's intentionally unused (e.g., replace "b, s_q, s_k, h_q, h_k, d
= problem_size" with an unpacking that prefixes s_k with an underscore) in the
prefill logic where variables b, s_q, h_q, h_k, d are used and h_r is computed
from h_q and h_k.
- Around line 45-51: The module currently suppresses all UserWarning globally by
calling warnings.filterwarnings("ignore", category=UserWarning); remove that
broad module-level filter and instead scope suppression to only the specific
unroll warning by introducing a context manager (e.g.,
suppress_loop_unroll_warning using warnings.catch_warnings and
warnings.filterwarnings with message="This loop is no longer unrolled and may
cause performance regression") and use that context manager at the specific call
sites in prefill.py where the unrolling warning is raised so other UserWarnings
remain visible.
- Around line 385-386: The unpacked thread index variable tidx from
cute.arch.thread_idx() is unused in the kernel entry; change its name to _tidx
to mark it as intentionally unused (i.e., replace "tidx, _, _ =
cute.arch.thread_idx()" with "_tidx, _, _ = cute.arch.thread_idx()") so
linters/readers know it's unused while keeping the other unpacked values and the
existing warp_idx assignment (warp_idx =
cute.arch.make_warp_uniform(cute.arch.warp_idx())) intact.
In `@flashinfer/cute_dsl/attention/roles/softmax.py`:
- Around line 336-344: The duplicated computation of thread_idx (calling
cute.arch.thread_idx(), taking tidx and computing tidx % (self.threads_per_warp
* (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids))))
appears twice; remove the redundant second block (the one at lines 366-373) so
thread_idx remains computed once and subsequent code uses the already-computed
thread_idx from the first occurrence; ensure any references after the removed
block still rely on the existing thread_idx variable and that no logic dependent
on re-calling cute.arch.thread_idx() is lost.
In `@flashinfer/cute_dsl/attention/scheduler/persistent.py`:
- Around line 38-45: In __new_from_mlir_values__ update the zip over
[self.is_persistent, self.problem_shape_mbh] and self._values_pos to use
zip(..., strict=True) to fail loudly on length mismatches, and when returning
the FmhaStaticTileSchedulerParams instance forward the current object's ip
parameter (pass loc=self._loc, ip=self.ip) so the new instance receives ip as
well; this touches the __new_from_mlir_values__ method, the attributes
self.is_persistent, self.problem_shape_mbh, self._values_pos, and the
FmhaStaticTileSchedulerParams constructor call.
- Around line 148-158: The hardcoded assertion in __new_from_mlir_values__
(assert len(values) == 10) is fragile; change it to compute the expected MLIR
value count by summing the MLIR-value counts of the constituent objects
(self._params, self._current_work_linear_idx, self._blk_coord, self._grid_shape)
using whatever helper/attribute your cutlass layer exposes (e.g., a mlir value
count helper or by querying each object's MLIR representation), then compare
len(values) to that computed total and raise a ValueError with a descriptive
message if mismatched; update the slicing logic that builds new_params,
new_current_work_linear_idx, new_blk_coord, and new_grid_shape to use those
computed per-object counts instead of fixed indices so FmhaStaticTileScheduler
construction remains correct.
In `@flashinfer/cute_dsl/attention/tmem_layout.py`:
- Around line 35-49: Extract the literal 512 used for TMEM capacity into a
module-level constant (e.g., SM100_TMEM_CAPACITY_COLUMNS = 512) and replace the
local variable in TmemLayout.from_config so the function uses that constant
instead of a magic number; update the top of the module with the constant and
ensure TmemLayout.from_config (which takes AttentionConfig and reads
config.mma_tiler[0]) references the new constant for alloc_cols so other code
can reuse the hardware-capacity value.
In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py`:
- Around line 393-396: The slice construction using zip(padding, shape_) can
silently ignore length mismatches; update the tuple comprehension that defines
slices (used to create torch_tensor from torch_tensor_full and assigned to
torch_tensor) to call zip with strict=True (i.e., zip(padding, shape_,
strict=True)) so any mismatch between padding and shape_ raises immediately and
makes debugging easier.
- Around line 129-157: The unpacked dummy tensors returned by
create_and_pad_tensor (q_ref, q_torch, k_ref, k_torch, v_ref, v_torch, o_torch)
are unused and should be prefixed with an underscore to indicate intentional
unused variables and silence linters; update the unpacking lines where
create_and_pad_tensor is called (for q_, k_, v_, and o_) to rename those
specific variables to _q_ref/_q_torch, _k_ref/_k_torch, _v_ref/_v_torch, and
_o_torch (or similar underscore-prefixed names) while keeping the used names
q_cute/k_cute/v_cute/o_cute unchanged.
- Around line 318-319: In batch_prefill.py replace the redundant explicit device
argument when creating the empty tensor so that out is created with
torch.empty_like(q) instead of torch.empty_like(q, device=q.device); locate the
assignment that sets out when out is None (the one referencing variables out and
q) and remove the device=q.device parameter to rely on torch.empty_like
inheriting q's device.
In `@tests/test_blackwell_fmha_attention.py`:
- Around line 1-13: The test module test_blackwell_fmha_attention.py is at the
tests/ root but belongs in the attention-specific kernel tests; move this suite
into a feature-specific tests subdirectory matching the kernel category (e.g.,
an attention/ or blackwell_fmha/ tests folder), update any relative imports
inside the module to the new location, and ensure all test callables in the file
are properly prefixed with test_ so pytest discovers them (check function names
and any parametrized fixtures used by functions in this module).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f4d60aef-cbff-41af-8cdd-f4914f9854d7
📒 Files selected for processing (24)
benchmarks/bench_blackwell_attention_cutedsl.py
flashinfer/cute_dsl/attention/__init__.py
flashinfer/cute_dsl/attention/collective_builder.py
flashinfer/cute_dsl/attention/config.py
flashinfer/cute_dsl/attention/fusion/__init__.py
flashinfer/cute_dsl/attention/fusion/mask.py
flashinfer/cute_dsl/attention/fusion/variant.py
flashinfer/cute_dsl/attention/mainloop_spec.py
flashinfer/cute_dsl/attention/pipeline_topology.py
flashinfer/cute_dsl/attention/prefill.py
flashinfer/cute_dsl/attention/roles/__init__.py
flashinfer/cute_dsl/attention/roles/correction.py
flashinfer/cute_dsl/attention/roles/epilogue.py
flashinfer/cute_dsl/attention/roles/loader_tma.py
flashinfer/cute_dsl/attention/roles/mma.py
flashinfer/cute_dsl/attention/roles/softmax.py
flashinfer/cute_dsl/attention/roles/softmax_math.py
flashinfer/cute_dsl/attention/scheduler/__init__.py
flashinfer/cute_dsl/attention/scheduler/persistent.py
flashinfer/cute_dsl/attention/tmem_layout.py
flashinfer/cute_dsl/attention/warp_schedule.py
flashinfer/cute_dsl/attention/wrappers/__init__.py
flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
tests/test_blackwell_fmha_attention.py
```python
if __name__ == "__main__":
    bench_fmha_cutedsl(128, 512, 32, 128, True, torch.bfloat16)
    bench_fmha_cutedsl(64, 1024, 32, 128, True, torch.bfloat16)
    bench_fmha_cutedsl(32, 2048, 32, 128, True, torch.bfloat16)
    bench_fmha_cutedsl(16, 4096, 32, 128, True, torch.bfloat16)
    bench_fmha_cutedsl(8, 8192, 32, 128, True, torch.bfloat16)
    bench_fmha_cutedsl(4, 16384, 32, 128, True, torch.bfloat16)
    bench_fmha_cutedsl(2, 32768, 32, 128, True, torch.bfloat16)
    bench_fmha_cutedsl(1, 65536, 32, 128, True, torch.bfloat16)
```
Skip the default sweep on unsupported GPUs.
The __main__ path unconditionally launches an SM100-only kernel. Running this script on another CUDA box will fail in JIT/launch instead of exiting cleanly with a clear message.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/bench_blackwell_attention_cutedsl.py` around lines 153 - 161, The
script currently unconditionally runs an SM100-only kernel in the __main__ block
(calls to bench_fmha_cutedsl), which will JIT/launch-fail on non-SM100 GPUs; add
a GPU capability check before running the default sweep: use
torch.cuda.is_available() and torch.cuda.get_device_capability() or
torch.cuda.get_device_properties(device).major/minor (or device name) to detect
whether the current GPU supports SM100, and if not, skip the default
bench_fmha_cutedsl(...) calls and exit or print a clear message; update the
__main__ section so the SM100-only sweep only runs when the capability check
passes.
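One way to structure the gate is a small pure helper plus a guard in `__main__` (the helper name `is_sm100` is illustrative; the capability tuple is what `torch.cuda.get_device_capability()` returns, e.g. `(10, 0)` on SM100-class Blackwell):

```python
def is_sm100(capability: tuple) -> bool:
    """True for SM100-class devices, i.e. compute capability 10.x.
    `capability` is the (major, minor) pair from
    torch.cuda.get_device_capability()."""
    return capability[0] == 10

# Sketch of the __main__ guard in the benchmark script:
#
# import torch
# if not (torch.cuda.is_available()
#         and is_sm100(torch.cuda.get_device_capability())):
#     print("Skipping SM100-only sweep: no SM100-capable GPU detected.")
#     raise SystemExit(0)
# ... run the bench_fmha_cutedsl(...) sweep ...
```

Keeping the capability check in a pure function makes it trivially testable without a GPU, while the guarded `__main__` exits cleanly instead of failing at JIT/launch time.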
```python
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
    pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage,
)
```
🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Check how p_tmem_layout_staged is used and if q_dtype is the correct type
rg -n "p_tmem_layout" --type py flashinfer/cute_dsl/attention/
```

🏁 Script executed:

```bash
#!/bin/bash
# Get the full function signature and beginning to see all dtype parameters
sed -n '50,120p' flashinfer/cute_dsl/attention/collective_builder.py
```

🏁 Script executed:

```bash
#!/bin/bash
# Check sm100_utils.make_smem_layout_a to see what dtype parameter does
rg -n "def make_smem_layout_a" --type py flashinfer/cute_dsl/attention/
```

🏁 Script executed:

```bash
#!/bin/bash
# Check what dtypes are typically used - look for q_dtype, k_dtype, v_dtype, p_dtype
rg -n "q_dtype|k_dtype|v_dtype|p_dtype|intermediate_dtype" --type py flashinfer/cute_dsl/attention/collective_builder.py | head -30
```
P buffer TMEM layout should use v_dtype, not q_dtype.
The P buffer (attention weights) shares the same pv_tiled_mma with V, which is created with v_dtype. However, p_tmem_layout_staged is created with q_dtype instead, causing a dtype mismatch. Change line 97 to use v_dtype:
Before:

```python
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
    pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage,
)
```

After:

```python
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
    pv_tiled_mma, config.pv_mma_tiler, v_dtype, mainloop.acc_stage,
)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 96 - 98,
The p_tmem_layout_staged is being created with the wrong dtype (q_dtype) causing
a mismatch with pv_tiled_mma which was created for V; update the call to
sm100_utils.make_smem_layout_a in collective_builder so p_tmem_layout_staged
uses v_dtype instead of q_dtype (the call that takes pv_tiled_mma,
config.pv_mma_tiler, q_dtype, mainloop.acc_stage should pass v_dtype) to align
the P buffer TMEM layout with the V buffer.
```python
elif mask_type == MaskType.SLIDING_WINDOW_MASK:
    first_q = blk_coord[0] * tile_shape[0]
    last_q = (blk_coord[0] + 1) * tile_shape[0] - 1
    min_kv = cutlass.max(0, first_q - window_left)
    max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
    start_block = min_kv // tile_shape[1]
    end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
    result = end_block - start_block
    return result
```
Sliding-window masking still ignores the Q/K length offset.
The causal path accounts for seqlen_k - seqlen_q, but the sliding-window branches compute trip counts, start blocks, and element masks from raw Q indices. For suffix-prefill or ragged batches where seqlen_q != seqlen_k, the window gets centered on the wrong KV region.
🛠️ One way to thread the existing offset through the sliding-window path

```diff
 elif mask_type == MaskType.SLIDING_WINDOW_MASK:
-    first_q = blk_coord[0] * tile_shape[0]
-    last_q = (blk_coord[0] + 1) * tile_shape[0] - 1
+    q_offset = seqlen_k - seqlen_q
+    first_q = blk_coord[0] * tile_shape[0] + q_offset
+    last_q = (blk_coord[0] + 1) * tile_shape[0] - 1 + q_offset
     min_kv = cutlass.max(0, first_q - window_left)
     max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
     start_block = min_kv // tile_shape[1]
     end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
     result = end_block - start_block
@@
 elif mask_type == MaskType.SLIDING_WINDOW_MASK:
     trip_count = get_trip_count(
-        mask_type, window_left, blk_coord, tile_shape, seqlen_k
+        mask_type, window_left, blk_coord, tile_shape, seqlen_k, seqlen_q
     )
     result = trip_count
@@
 if cutlass.const_expr(mask_type == MaskType.SLIDING_WINDOW_MASK):
-    first_q = blk_coord[0] * tile_shape[0]
+    q_offset = seqlen_k - seqlen_q
+    first_q = blk_coord[0] * tile_shape[0] + q_offset
     min_kv = cutlass.max(0, first_q - window_left)
     return min_kv // tile_shape[1]
@@
 elif mask_type == MaskType.SLIDING_WINDOW_MASK:
     for i in range(cute.size(acc_qk)):
         pos = index_qk[i]
-        if pos[1] - pos[0] > window_left or pos[0] - pos[1] > window_left or pos[1] >= seqlen_k:
+        q_pos = pos[0] + causal_offset
+        if pos[1] - q_pos > window_left or q_pos - pos[1] > window_left or pos[1] >= seqlen_k:
             acc_qk[i] = -Float32.inf
```

Also applies to: 85-88, 129-143, 146-170
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/fusion/mask.py` around lines 45 - 53,
Sliding-window masking currently centers the KV window on raw Q indices (see
MaskType.SLIDING_WINDOW_MASK handling using blk_coord, tile_shape, window_left,
seqlen_k) and ignores the Q/K length offset used by the causal path; compute
q_k_offset = seqlen_k - seqlen_q and add it to first_q and last_q (or otherwise
shift Q indices into KV space) before calculating min_kv, max_kv, start_block,
end_block, and any element masks; apply the same fix to the other sliding-window
blocks noted (around the other occurrences at the given ranges) so all
sliding-window computations use shifted Q indices into KV coordinate space.
```python
@cute.jit
def update_statistics(self, kv_tile_idx, qo_head_idx, m, d, scale):
    log2_e = math.log2(math.exp(1.0))
    sink_raw = (
        self.params[qo_head_idx] * log2_e / scale
        if kv_tile_idx == 0
        else -math.inf
    )
    m_new = sink_raw if sink_raw > m else m
    rescale = cute.arch.exp2((m - m_new) * scale)
    d_new = cute.arch.exp2((sink_raw - m_new) * scale) + d * rescale
    return m_new, d_new
```
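The update above can be checked numerically on the host. The sketch below mirrors the kernel math in pure Python (a hypothetical flat signature with an explicit `is_first_tile` flag replaces `kv_tile_idx`/`self.params`; `2.0 ** x` stands in for `cute.arch.exp2`), and verifies the online-softmax invariant that `d * 2**(m * scale)` equals the raw sum of exponentials including the sink:

```python
import math

def update_statistics(sink_logit, m, d, scale, is_first_tile):
    """Pure-Python mirror of the kernel update: m is the running max and
    d the running denominator, both tracked in the exp2 domain scaled
    by `scale`. The sink is folded in exactly once, on the first tile."""
    sink_raw = sink_logit if is_first_tile else -math.inf
    m_new = max(sink_raw, m)
    rescale = 2.0 ** ((m - m_new) * scale)                 # exp2 on device
    d_new = 2.0 ** ((sink_raw - m_new) * scale) + d * rescale
    return m_new, d_new
```

On non-first tiles, `sink_raw = -inf` makes the sink term `2**(-inf) = 0.0` and `rescale = 1`, so the statistics pass through unchanged, which is the behavior the `kv_tile_idx == 0` guard encodes.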
🧩 Analysis chain
🌐 Web query:
CuTe DSL attention sink implementation softmax denominator
💡 Result:
No specific "attention sink implementation" in CuTe DSL was found. CuTe DSL (from NVIDIA CUTLASS) is used for high-performance FlashAttention kernels, including softmax computation via online softmax for numerical stability. The softmax denominator (row sums of exp(scores - row_max)) is computed using tiled reductions in shared memory across blocks/tiles of the attention matrix, leveraging CuTe's tensor layouts, TensorSSA for vectorized ops, and arch-specific reductions like cute.arch.warp_reduction_sum (as in RMSNorm examples, analogous for softmax rows). Official examples like CUTLASS's flash_attention_v2.py feature softmax_rescale_O for online softmax fusion with GEMM. Repos like luliyucoordinate/cute-flash-attention and HuyNguyen-hust/flash-attn-101 implement FlashAttention in CuTe, using standard online softmax without special "sink" handling—sinks are a training/model phenomenon, not a kernel feature. FlashAttention-3/4 in Dao-AILab/flash-attention use CuTe DSL for Hopper/Blackwell with optimized exp2/emulation and reductions for the denominator. Attention sinks refer to probability mass concentrating on fixed positions in softmax attention (e.g., arXiv:2603.11487), but kernel implementations compute general softmax denominators via parallel reductions, not sink-specific logic.
Citations:
- 1: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py
- 2: https://veitner.bearblog.dev/simple-reduction-in-cutedsl/
- 3: https://github.com/luliyucoordinate/cute-flash-attention
- 4: https://github.com/HuyNguyen-hust/flash-attn-101
- 5: Rationale behind using ex2_emulation in FA4 Softmax implementation vs standard exp2 Dao-AILab/flash-attention#2081
- 6: https://arxiv.org/pdf/2312.11918
- 7: https://research.colfax-intl.com/a-users-guide-to-flexattention-in-flash-attention-cute-dsl/
🏁 Script executed:

```bash
find . -type f -name "variant.py" | grep -E "cute_dsl|attention"
```

🏁 Script executed:

```bash
if [ -f "flashinfer/cute_dsl/attention/fusion/variant.py" ]; then
    wc -l flashinfer/cute_dsl/attention/fusion/variant.py
fi
```

🏁 Script executed:

```bash
if [ -f "flashinfer/cute_dsl/attention/fusion/variant.py" ]; then
    sed -n '350,395p' flashinfer/cute_dsl/attention/fusion/variant.py
fi
```

🏁 Script executed:

```bash
# Search for scale parameter definition and sink-related documentation
rg -n "scale\s*=" flashinfer/cute_dsl/attention/fusion/variant.py | head -20
```

🏁 Script executed:

```bash
# Look for documentation about sink semantics
rg -n -B3 -A3 "sink" flashinfer/cute_dsl/attention/fusion/variant.py | head -40
```

🏁 Script executed:

```bash
# Search for params initialization and documentation
rg -n -B5 "self.params" flashinfer/cute_dsl/attention/fusion/variant.py | head -50
```

🏁 Script executed:

```bash
sed -n '340,410p' flashinfer/cute_dsl/attention/fusion/variant.py
```

🏁 Script executed:

```bash
# Search for AttentionWithSink usage and test cases
rg -n "AttentionWithSink" --type py
```

🏁 Script executed:

```bash
# Look for documentation around the sink parameter
sed -n '340,365p' flashinfer/cute_dsl/attention/fusion/variant.py
```

🏁 Script executed:

```bash
# Look at test cases to understand sink value semantics
sed -n '420,450p' tests/test_blackwell_fmha_attention.py
```

🏁 Script executed:

```bash
# Check more of the test context
sed -n '715,780p' tests/test_blackwell_fmha_attention.py
```

🏁 Script executed:

```bash
# Look for any documentation about sink value domain/units
rg -n -B5 -A5 "sink" flashinfer/cute_dsl/attention/fusion/variant.py | tail -80
```

🏁 Script executed:

```bash
# Check if there's a reference paper or documentation file
find . -type f -name "*.md" -o -name "*.rst" | xargs grep -l "sink" 2>/dev/null
```
🏁 Script executed:

```bash
# Find the reference implementation to understand sink semantics
rg -n "def attention_ref" tests/test_blackwell_fmha_attention.py
```

🏁 Script executed:

```bash
# Look at the reference implementation
sed -n '1,120p' tests/test_blackwell_fmha_attention.py
```

🏁 Script executed:

```bash
# Search for attention_ref implementation across the repo
find . -type f -name "*.py" -exec grep -l "def attention_ref" {} \;
```

🏁 Script executed:

```bash
# Find the sink_softmax implementation
find . -type f -name "sink_attention_reference.py"
```

🏁 Script executed:

```bash
# Search for sink_softmax definition
rg -n "def sink_softmax" --type py
```

🏁 Script executed:

```bash
cat tests/test_helpers/sink_attention_reference.py
```

🏁 Script executed:

```bash
# Look at how sink values are used in tests - what domain are they in?
sed -n '420,435p' tests/test_blackwell_fmha_attention.py
```

🏁 Script executed:

```bash
# Check the documentation docstring more carefully
sed -n '337,356p' flashinfer/cute_dsl/attention/fusion/variant.py
```

🏁 Script executed:

```bash
# Look for any comments about sink semantics in scale context
rg -n -B10 -A10 "log2_e\|log2.math" flashinfer/cute_dsl/attention/fusion/variant.py | head -60
```

🏁 Script executed:

```bash
# Search for how scale is set/passed to update_statistics
rg -n "update_statistics.*scale" --type py -B3 -A3 | head -60
```

🏁 Script executed:

```bash
# Look for where scale is initialized/computed in the wrapper
rg -n "scale.*=" flashinfer/cute_dsl/attention/ --type py | grep -v test | head -30
```
🏁 Script executed:
# Check if there's any comment about sink value domain in tests
rg -n -B5 -A5 "torch.randn.*sink" tests/ --type py | head -80Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at the softmax_math.py to understand how m and d work
sed -n '1,50p' flashinfer/cute_dsl/attention/roles/softmax_math.pyRepository: flashinfer-ai/flashinfer
Length of output: 1430
🏁 Script executed:
# Check what domain m and d are actually in
rg -n "m_new\|d_new" flashinfer/cute_dsl/attention/fusion/variant.py -B2 -A2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Search for documentation about expected sink value domain/units
rg -n "sink" flashinfer/cute_dsl/attention/fusion/variant.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 1018
🏁 Script executed:
# Let me check the documentation more carefully for scale parameter
sed -n '50,100p' flashinfer/cute_dsl/attention/fusion/variant.pyRepository: flashinfer-ai/flashinfer
Length of output: 1926
🏁 Script executed:
# Look at the complete context of the update_statistics docstring
sed -n '275,315p' flashinfer/cute_dsl/attention/fusion/variant.pyRepository: flashinfer-ai/flashinfer
Length of output: 1503
🏁 Script executed:
# Let me verify the mathematical simplification one more time
python3 << 'EOF'
import math
# Verify: if scale = log2(e) * sm_scale, then params * log2_e / scale = params / sm_scale
log2_e = math.log2(math.exp(1.0))
print(f"log2(e) = {log2_e}")
print(f"log2(e) ≈ 1.4426950408889634")
# For any sm_scale, verify the simplification
sm_scale = 1.0 / math.sqrt(128) # example: head_dim = 128
scale = log2_e * sm_scale
print(f"\nFor head_dim=128: sm_scale = {sm_scale}")
print(f"scale = log2(e) * sm_scale = {scale}")
# Test the simplification with a sample sink value
params = 2.5
sink_raw_original = params * log2_e / scale
sink_raw_simplified = params / sm_scale
print(f"\nFor params = {params}:")
print(f"sink_raw (original formula) = params * log2_e / scale = {sink_raw_original}")
print(f"sink_raw (simplified formula) = params / sm_scale = {sink_raw_simplified}")
print(f"Match: {abs(sink_raw_original - sink_raw_simplified) < 1e-10}")
EOF
🏁 Script executed:
# Check if sink value domain is documented in docstring or README
find . -type f \( -name "README.md" -o -name "*.rst" \) -exec grep -l "sink" {} \;

🏁 Script executed:
# Final check: Look at the complete example in the docstring
sed -n '131,150p' flashinfer/cute_dsl/attention/fusion/variant.py
Clarify the expected domain of sink tensor values in the docstring.
The mathematical conversion sink_raw = params * log2_e / scale is correct and simplifies to params / sm_scale as stated. However, the class docstring does not document the domain in which sink values are expected. The reference implementation (sink_softmax in sink_attention_reference.py) concatenates sink values directly to logits, confirming they should be in the logit domain (raw QK dot-product units, unnormalized). Add this to the sink parameter documentation to clarify that users must provide sink values in the logit domain, not pre-converted to log2 scale.
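The logit-domain convention can be illustrated with a minimal sketch of the reference semantics (the helper name here is hypothetical; the real sink_softmax lives in tests/test_helpers/sink_attention_reference.py):

```python
import torch

def sink_softmax_sketch(logits: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: sink values are in the logit domain and are
    # appended as one extra "virtual" KV column before normalization,
    # then dropped from the returned probabilities.
    qo_len = logits.shape[0]
    full = torch.cat([logits, sink.expand(qo_len, 1)], dim=-1)
    probs = torch.softmax(full, dim=-1)
    return probs[:, :-1]  # rows sum to < 1; the sink absorbed the remainder
```

Because the sink enters the softmax before any log2-scale conversion, callers would pass sinks in raw Q·K units; the kernel's params / sm_scale conversion described above then maps them into its internal log2 domain.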
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/fusion/variant.py` around lines 367 - 378,
Update the class and relevant parameter docstrings to state that sink values are
expected in the logit domain (raw Q·K dot-product units, unnormalized), not
pre-scaled to log2; specifically mention this near the documentation for the
sink parameter(s) used by update_statistics and the self.params/sink_raw
conversion (which divides by scale/log2_e), and add a cross-reference to
sink_softmax in sink_attention_reference.py so callers know sinks are
concatenated to logits before any log2 scaling.
@cute.jit
def score_mod(self, score, batch_idx, qo_idx, kv_idx,
              qo_head_idx, kv_head_idx):
    return self.cap * cute.arch.tanh(score * self.rcp_cap)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Search for tanh usage or definition in cutlass/cute
rg -n "tanh" --type py -g '!*test*' | head -30
# Check if there's documentation or other usage
rg -n "cute\.arch\.tanh" --type py

🏁 Script executed:
#!/bin/bash
# First, let's look at the file structure and imports in variant.py
head -50 flashinfer/cute_dsl/attention/fusion/variant.py | cat -n
# Check imports related to cute
rg -n "^import cute|^from cute" flashinfer/cute_dsl/attention/fusion/variant.py

🏁 Script executed:
#!/bin/bash
# Look for cute.arch definition or where it comes from
rg -n "cute\.arch" flashinfer/cute_dsl/attention/fusion/variant.py -B 2 -A 2 | head -50
# Check if cute is a local module or external dependency
find . -name "cute.py" -o -name "cute" -type d | grep -v __pycache__

🏁 Script executed:
#!/bin/bash
# Check for cute library in requirements or setup files
find . -name "requirements*.txt" -o -name "setup.py" -o -name "pyproject.toml" | xargs grep -l cute 2>/dev/null | head -10
# Look at the actual file content around the usages
sed -n '550,560p' flashinfer/cute_dsl/attention/fusion/variant.py | cat -n
sed -n '165,175p' flashinfer/cute_dsl/attention/fusion/variant.py | cat -n

🏁 Script executed:
#!/bin/bash
# Look for the cute module or arch submodule definition
find . -name "*.py" -path "*/cute*" | head -20
# Search for arch module or its definition
rg -n "class arch|def tanh|def exp2|def rcp_approx" --type py | head -30

🏁 Script executed:
#!/bin/bash
# Check if there are any bindings or stubs for cute
rg -n "cutlass" . --type py -l | head -20
# Look for any .pyi stub files or type hints
find . -name "*.pyi" | head -10
# Check requirements or dependencies
cat pyproject.toml 2>/dev/null | head -50
cat setup.py 2>/dev/null | head -50

🏁 Script executed:
#!/bin/bash
# Search for tests that use cute.arch functions
rg -n "cute\.arch\.(tanh|exp2|rcp_approx)" --type py -l
# Check if there's error handling around these calls
rg -n "cute\.arch" flashinfer/cute_dsl/attention/fusion/variant.py -B 3 -A 3 | head -80

🏁 Script executed:
#!/bin/bash
# Check what's actually exported from cutlass.cute
rg -n "from cutlass" flashinfer/cute_dsl/attention/fusion/variant.py -A 5 -B 5
# Look for how cute module is used in other files
rg -n "cute\.arch" flashinfer/fused_moe/cute_dsl/blackwell/utils.py -A 2 -B 2 | head -30
# Check if there are any imports or type stubs
rg -n "cutlass\.cute" . --type py | head -20

🏁 Script executed:
#!/bin/bash
# Search for any CUTLASS bindings or Python wrapping
find . -name "*cute*" -type f | grep -E "\.(py|pyi|so|pyd)" | head -20
# Look for documentation or examples of cute.arch usage
rg -n "cute\.arch\." . --type py -l | xargs head -100 | head -50
🌐 Web query:
CUTLASS cute arch tanh exp2 rcp_approx Python bindings documentation
💡 Result:
CUTLASS provides Python bindings through its CuTe DSL (Domain-Specific Language), enabling high-performance CUDA kernel development in Python. The CuTe architecture module (cutlass.cute.arch) includes wrappers for NVVM-level MLIR operations that integrate with CuTe Python types.
Key functions in cutlass.cute.arch relevant to the query:
- rcp_approx(a: float | cutlass.cute.typing.Float32, *, loc=None, ip=None) → cutlass.cute.typing.Float32: Computes approximate reciprocal.
- exp2(a: float | cutlass.cute.typing.Float32, *, loc=None, ip=None) → cutlass.cute.typing.Float32: Computes base-2 exponential.
- No explicit tanh function found in the arch module; it may require composition of other operations or be available elsewhere in CuTe.
Official documentation:
- Primary source: https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/cute_arch.html (cutlass.cute.arch API)
- Installation: Use setup.sh from the CUTLASS repo (e.g., ./cutlass/python/CuTeDSL/setup.sh --cu13 for CUDA 13.1) or pip install nvidia-cutlass-dsl[cu13].
- GitHub: https://github.com/NVIDIA/cutlass (v4.4.2 as of 2026-03-17).
Example usage in a @cute.jit kernel:
import cutlass.cute as cute
@cute.jit
def kernel(a: cute.Tensor):
    recip = cute.arch.rcp_approx(a[0])
    exp = cute.arch.exp2(recip)
CuTe is part of CUTLASS 4.x, supporting NVIDIA architectures from Ampere to Blackwell.
Citations:
- 1: https://docs.nvidia.com/cutlass/4.2.1/media/docs/pythonDSL/cute_dsl_api/cute_arch.html
- 2: https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/cute_arch.html
- 3: https://github.com/NVIDIA/cutlass
- 4: https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/cute.html
- 5: https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html
- 6: https://docs.nvidia.com/cutlass/4.4.0/genindex.html
🏁 Script executed:
#!/bin/bash
# Search for tanh in the entire codebase to see how it should be implemented
rg -n "tanh" flashinfer/cute_dsl --type py -A 3 -B 3
# Check if there's cute.math module with tanh
rg -n "cute\.math\." flashinfer/cute_dsl --type py | head -20
# Look for any definitions of tanh function
rg -n "def tanh" . --type py

🏁 Script executed:
#!/bin/bash
# Check if there are any test files that test SoftCappingAttention
rg -n "SoftCappingAttention" . --type py
# Look for uses of score_mod in tests
rg -n "score_mod" tests/ --type py -A 5 | head -40
🌐 Web query:
CUTLASS cute tanh hyperbolic tangent math operations
💡 Result:
CUTLASS and CuTe do not provide built-in support for hyperbolic tangent (tanh) math operations within their core GEMM kernels or tensor algorithms. Tanh is available in related NVIDIA libraries but requires custom implementation for fusion in CUTLASS epilogues or CuTe usage.
Key findings:
- CuTe tensor algorithms (cute/algorithm/) support copy, gemm, axpby, fill, clear, but no tanh or other unary activations are listed in the official docs.
- CUTLASS epilogues support fused activations via cutlass::epilogue::thread::LinearCombination + Activation functors (e.g., ReLU, GELU, SiLU/LeakyReLU in activation.h), but no built-in tanh functor. Custom tanh can be implemented by extending LinearCombinationGeneric or Epilogue Visitor Trees (EVT) for CUTLASS 3.x, as confirmed in GitHub issues where users request sigmoid/tanh fusion.
- Separate NVIDIA libraries provide tanh: cuda.tile.tanh in cuTile Python (tile-level), nv_tanh in libdevice (scalar). These are not directly integrated into CUTLASS/CuTe GEMM epilogues.
- For performance-critical use (e.g., fused GEMM + tanh), implement a custom epilogue op calling CUDA math intrinsics like nv_tanhf (fast float version), or an approximate tanh (e.g., K-TanH from research papers).
To use: in CUTLASS 3.x, extend Sm90EVT or similar with a custom compute node using tanh; in 2.x, subclass LinearCombination with a custom activation functor calling tanh. No official CUTLASS 4.x (as of 2026-03) release notes mention a tanh addition.
Citations:
- 1: https://docs.nvidia.com/cuda/cutile-python/generated/cuda.tile.tanh.html
- 2: https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html
- 3: https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html
- 4: https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/cute.html
- 5: https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api.html
- 6: [QST]question about cutlass epilogue customization NVIDIA/cutlass#1288
- 7: https://research.colfax-intl.com/epilogue_visitor_tree
- 8: https://ipd.graylab.jhu.edu/rfdiffusion2/cutlass-3.5.1/docs/structcutlass_1_1epilogue_1_1thread_1_1LinearCombinationRelu_3_01ElementOutput___00_01Count_00_00274a94522c46cd041d0b10d484e2ef3.html
- 9: [FEA] LinearCombinationSilu epilogue NVIDIA/cutlass#476
- 10: https://ipd.graylab.jhu.edu/rfdiffusion2/cutlass-3.5.1/docs/classcutlass_1_1epilogue_1_1thread_1_1LinearCombination.html
- 11: https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/thread/linear_combination_relu.h
- 12: https://ipd.graylab.jhu.edu/rfdiffusion2/cutlass-3.5.1/docs/structcutlass_1_1epilogue_1_1thread_1_1LinearCombinationClamp_1_1Params.html
- 13: https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html
cute.arch.tanh is not available in CUTLASS CuTe — this will cause JIT compilation to fail.
SoftCappingAttention.score_mod uses cute.arch.tanh, but this function does not exist in the CuTe DSL. CUTLASS/CuTe provides cute.arch.rcp_approx and cute.arch.exp2, but not tanh. The kernel will fail at JIT compile time with an AttributeError.
Implement tanh using available operations (e.g., composition via exp2 and rcp_approx) or use an alternative approximation.
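As a concrete sketch of that workaround, here is a pure-Python model of the composition, with stand-in functions for cute.arch.exp2 and cute.arch.rcp_approx (all names and constants here are illustrative, not the kernel's actual API):

```python
import math

LOG2_E = 1.4426950408889634  # log2(e)

def exp2_(x: float) -> float:
    return 2.0 ** x  # stand-in for cute.arch.exp2

def rcp_(x: float) -> float:
    return 1.0 / x  # stand-in for cute.arch.rcp_approx

def tanh_approx(x: float) -> float:
    # tanh(x) = sign(x) * (1 - 2 / (exp(2|x|) + 1)),
    # with exp(2|x|) rewritten as exp2(2|x| * log2(e)).
    ax = abs(x)
    e = exp2_(2.0 * ax * LOG2_E)
    t = 1.0 - 2.0 * rcp_(e + 1.0)
    return t if x >= 0.0 else -t
```

Using abs(x) keeps the exponent non-positive-free of overflow concerns on the small side, and the identity saturates cleanly to ±1 for large |x|; inside a @cute.jit kernel the same structure would use the hardware exp2/rcp approximations, trading a little accuracy for speed.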
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/fusion/variant.py` around lines 551 - 554,
SoftCappingAttention.score_mod calls non-existent cute.arch.tanh; replace it
with a local tanh approximation implemented using available primitives (e.g.,
cute.arch.exp2 and cute.arch.rcp_approx) or a cheap rational polynomial and call
that from score_mod. Add a helper function (e.g., _tanh_approx(x)) in the same
class or module and use it in SoftCappingAttention.score_mod (referencing
self.cap and self.rcp_cap as before), implementing tanh(x) via exp2 by computing
exp(-2*abs(x)) with exp2(-2*abs(x)/ln2) plus sign handling or by a stable
rational approximation (polynomial numerator/denominator) and ensure the helper
uses cute.jit-compatible operations only.
When cluster_scale > 1, the all-thread side of
UMMA_ASYNC / ASYNC_UMMA pipelines multiplies its thread count by
cluster_scale. TMA_UMMA pipelines are unaffected (leader-only on both sides).
"""

name: str
pipeline_type: PipelineType
stages: int
producer_warp_ids: Tuple[int, ...]
consumer_warp_ids: Tuple[int, ...]
tx_count_key: str | None = None
cluster_scale: int = 1
cluster_scale is currently a no-op.
The dataclass docs say this field changes participant counts, but create_pipelines() always instantiates the producer/consumer groups as if it were 1. Any topology that starts using cluster_scale > 1 will build barriers with the wrong arrive counts.
🛠️ Either honor the field or fail fast when it is set
  prod_threads = edge.pipeline_type.producer_thread_count(
      len(edge.producer_warp_ids), threads_per_warp
  )
  cons_threads = edge.pipeline_type.consumer_thread_count(
      len(edge.consumer_warp_ids), threads_per_warp
  )
+ if edge.cluster_scale != 1:
+     if edge.pipeline_type == PipelineType.UMMA_ASYNC:
+         cons_threads *= edge.cluster_scale
+     elif edge.pipeline_type == PipelineType.ASYNC_UMMA:
+         prod_threads *= edge.cluster_scale
+     else:
+         raise ValueError(
+             f"cluster_scale is unsupported for {edge.pipeline_type}"
+         )

Also applies to: 127-147
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/pipeline_topology.py` around lines 68 - 79, The
Pipeline.dataclass field cluster_scale is ignored by create_pipelines(), causing
incorrect participant and barrier arrive counts; either (preferred) honor it by
multiplying the all-thread side's participant counts when constructing
producer/consumer groups and computing barrier arrive counts for PipelineType
values UMMA_ASYNC and ASYNC_UMMA (but leave TMA_UMMA unchanged), i.e., when
building groups from producer_warp_ids/consumer_warp_ids in create_pipelines()
multiply the thread counts by pipeline.cluster_scale and use that scaled value
when setting arrive counts for barriers/tx_count_key, or fail fast by adding a
check in create_pipelines() that raises a clear exception if
pipeline.cluster_scale != 1 so callers must handle scaling explicitly.
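To make the intended semantics concrete, here is a hedged pure-Python sketch of how the arrive counts would scale per pipeline type (class and function names are hypothetical stand-ins, not the module's actual API):

```python
from enum import Enum

class PipelineType(Enum):
    # Hypothetical mirror of the pipeline kinds discussed above.
    TMA_UMMA = "tma_umma"
    UMMA_ASYNC = "umma_async"
    ASYNC_UMMA = "async_umma"

def participant_counts(ptype, n_prod_warps, n_cons_warps,
                       threads_per_warp=32, cluster_scale=1):
    """Sketch: only the all-thread side's arrive count scales with cluster_scale."""
    prod = n_prod_warps * threads_per_warp
    cons = n_cons_warps * threads_per_warp
    if ptype == PipelineType.TMA_UMMA:
        return 1, 1  # leader-only on both sides; cluster_scale has no effect
    if ptype == PipelineType.UMMA_ASYNC:
        return 1, cons * cluster_scale  # consumer side is all-thread
    if ptype == PipelineType.ASYNC_UMMA:
        return prod * cluster_scale, 1  # producer side is all-thread
    raise ValueError(ptype)
```

With this shape, a topology that sets cluster_scale=2 on a UMMA_ASYNC edge doubles only the consumer arrive count, which is the behavior the dataclass docstring promises and create_pipelines() currently omits.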
# NOTE on CuTe DSL JIT limitations:
# - partition_output(): Returns tensor tuples — CuTe DSL JIT does not
#   reliably handle returning tensors from @cute.jit methods.
# - store_tile(): SAFE — takes pre-sliced tensors as arguments, no
#   runtime indexing or return values. Used in run() successfully.
# =========================================================================

@cute.jit
def partition_output(
    self,
    tma_atom_o: cute.CopyAtom,
    mO_qdl: cute.Tensor,
    sO: cute.Tensor,
    block_coord: tuple,
):
    """Partition output global tensor for TMA stores. Returns (tOsO, tOgO)."""
    gO_qdl = cute.flat_divide(
        mO_qdl, cute.select(self.pv_mma_tiler, mode=[0, 1])
    )
    gO = gO_qdl[None, None, None, 0, block_coord[2]]
    tOsO, tOgO = cute.nvgpu.cpasync.tma_partition(
        tma_atom_o, 0, cute.make_layout(1),
        cute.group_modes(sO, 0, 2),
        cute.group_modes(gO, 0, 2),
    )
    return tOsO, tOgO
🧩 Analysis chain
🌐 Web query:
Does NVIDIA CUTLASS CuTe DSL reliably support returning tensor objects from @cute.jit functions or methods, or should helpers that return tensors stay non-jitted/inlined?
💡 Result:
No, NVIDIA CUTLASS CuTe DSL does not reliably support returning tensor objects from @cute.jit functions or methods. Helpers that return tensors should stay non-jitted/inlined.
Remove @cute.jit from partition_output() or refactor to avoid returning tensors.
This method is decorated with @cute.jit but returns a tensor tuple, contradicting the documented limitation directly above it: "CuTe DSL JIT does not reliably handle returning tensors from @cute.jit methods." Since this is exported as a reusable primitive, it violates its own safety contract. Either remove the @cute.jit decorator or restructure to keep tensor returns outside JIT compilation per NVIDIA documentation.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/roles/epilogue.py` around lines 41 - 66,
partition_output is incorrectly decorated with `@cute.jit` while returning tensor
objects (tOsO, tOgO) which violates the CuTe JIT limitation; either remove the
`@cute.jit` decorator from partition_output so it runs as a normal Python method,
or refactor it to avoid returning tensors by (a) accepting preallocated output
containers/handles and writing into them, or (b) moving the
cute.nvgpu.cpasync.tma_partition call out of the `@cute.jit` function into a
non-jit wrapper (e.g., create partition_output_nonjit that calls
cute.nvgpu.cpasync.tma_partition and returns tensors or change partition_output
to populate passed-in tensor references); update references to partition_output
accordingly so no `@cute.jit` function returns tensors (symbols: partition_output,
tOsO, tOgO, tma_partition, tma_atom_o).
@dataclass(frozen=True)
class WarpSchedule:
    """Defines warp role assignment and register budgets for attention kernels.

    Each field maps directly to C++ CUTLASS's KernelSchedule:
    - Warp ID ranges for each role
    - Register allocation per role (controls spill/occupancy tradeoff)
    - Barrier IDs for CTA sync and TMEM allocation
    """

    softmax0_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
    softmax1_warp_ids: Tuple[int, ...] = (4, 5, 6, 7)
    correction_warp_ids: Tuple[int, ...] = (8, 9, 10, 11)
    mma_warp_id: int = 12
    load_warp_id: int = 13
    epilogue_warp_id: int = 14
    empty_warp_id: int = 15

    num_regs_softmax: int = 192
    num_regs_correction: int = 96
    num_regs_other: int = 32
    num_regs_empty: int = 24

    threads_per_warp: int = 32
    cta_sync_bar_id: int = 0
    tmem_alloc_sync_bar_id: int = 1

    @property
    def all_warp_ids(self) -> Tuple[int, ...]:
        return (
            *self.softmax0_warp_ids,
            *self.softmax1_warp_ids,
            *self.correction_warp_ids,
            self.mma_warp_id,
            self.load_warp_id,
            self.epilogue_warp_id,
            self.empty_warp_id,
        )

    @property
    def num_warps(self) -> int:
        return len(self.all_warp_ids)

    @property
    def threads_per_cta(self) -> int:
        return self.threads_per_warp * self.num_warps

    @property
    def num_warps_per_warpgroup(self) -> int:
        return 4

    @property
    def softmax_warpgroup_count(self) -> int:
        total_softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
        return total_softmax_warps // self.num_warps_per_warpgroup
Validate custom schedules before deriving CTA sizes.
num_warps, threads_per_cta, and softmax_warpgroup_count all assume the warp ids are unique, contiguous from 0, and that the softmax warps fill whole warpgroups. With a custom WarpSchedule, duplicate/gapped ids or a non-multiple-of-4 softmax set will silently produce the wrong CTA/barrier sizing.
🛠️ Suggested fail-fast validation
 @dataclass(frozen=True)
 class WarpSchedule:
 @@
     threads_per_warp: int = 32
     cta_sync_bar_id: int = 0
     tmem_alloc_sync_bar_id: int = 1
+
+    def __post_init__(self):
+        all_warp_ids = self.all_warp_ids
+        if len(set(all_warp_ids)) != len(all_warp_ids):
+            raise ValueError("warp ids must be unique across roles")
+        if tuple(sorted(all_warp_ids)) != tuple(range(len(all_warp_ids))):
+            raise ValueError("warp ids must form a contiguous range starting at 0")
+        total_softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
+        if total_softmax_warps % self.num_warps_per_warpgroup != 0:
+            raise ValueError("softmax warps must fill whole warpgroups")
+        if self.cta_sync_bar_id == self.tmem_alloc_sync_bar_id:
+            raise ValueError("barrier ids must be distinct")

🤖 Prompt for AI Agents
+ raise ValueError("barrier ids must be distinct")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/warp_schedule.py` around lines 17 - 71, Add a
fail-fast validation in WarpSchedule (implement in a __post_init__ method) that
verifies: 1) all_warp_ids (built from softmax0_warp_ids, softmax1_warp_ids,
correction_warp_ids, mma_warp_id, load_warp_id, epilogue_warp_id, empty_warp_id)
contain unique values and form a contiguous range starting at 0 up to
len(all_warp_ids)-1, and 2) the total number of softmax warps
(len(softmax0_warp_ids)+len(softmax1_warp_ids)) is divisible by
num_warps_per_warpgroup; on violation raise ValueError with a clear message
referencing the failing condition so consumers of num_warps, threads_per_cta,
and softmax_warpgroup_count cannot silently compute incorrect sizes.
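A standalone sketch (plain Python, with a hypothetical reduced field set rather than the real WarpSchedule) showing how such a __post_init__ guard behaves on a frozen dataclass:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class MiniSchedule:
    # Hypothetical reduced schedule for illustration only; it mirrors the
    # proposed validation, not the actual WarpSchedule fields.
    softmax_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
    mma_warp_id: int = 4
    warps_per_warpgroup: int = 4

    def __post_init__(self):
        # Reading fields is allowed in __post_init__ even when frozen=True;
        # only assignment is blocked.
        ids = (*self.softmax_warp_ids, self.mma_warp_id)
        if len(set(ids)) != len(ids):
            raise ValueError("warp ids must be unique across roles")
        if tuple(sorted(ids)) != tuple(range(len(ids))):
            raise ValueError("warp ids must form a contiguous range starting at 0")
        if len(self.softmax_warp_ids) % self.warps_per_warpgroup != 0:
            raise ValueError("softmax warps must fill whole warpgroups")
```

Construction with the defaults succeeds, while a duplicate or gapped warp id fails at dataclass construction time instead of producing a silently wrong CTA size later.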
self._has_params = self._variant.extra_params is not None
if self._has_params:
    ep = self._variant.extra_params.to(torch.float32).to(self._device)
    if not ep.is_contiguous():
        raise ValueError(
            f"AttentionVariant.extra_params must be contiguous, "
            f"got strides {ep.stride()} for shape {ep.shape}. "
            f"Call .contiguous() before returning from extra_params."
        )
    self._params_torch = ep
    params_cute = from_dlpack(ep, assumed_align=16)
Potential NameError if params_cute is referenced when _has_params is False.
params_cute is defined only inside the if self._has_params: block (line 169), but it's referenced at line 240.
Looking more closely, line 240 uses the conditional expression params_cute.iterator if self._has_params else None. Python evaluates the condition before either branch, so the params_cute lookup is skipped when _has_params is False and no NameError is actually raised at runtime. The code is still fragile, though: it silently depends on _has_params and the definition of params_cute staying in sync. Initializing params_cute = None up front makes the guard self-contained.
Proposed fix
 if self._has_params:
     ep = self._variant.extra_params.to(torch.float32).to(self._device)
     if not ep.is_contiguous():
         raise ValueError(
             f"AttentionVariant.extra_params must be contiguous, "
             f"got strides {ep.stride()} for shape {ep.shape}. "
             f"Call .contiguous() before returning from extra_params."
         )
     self._params_torch = ep
     params_cute = from_dlpack(ep, assumed_align=16)
+else:
+    params_cute = None

Then at line 240:
-    params_cute.iterator if self._has_params else None,
+    params_cute.iterator if params_cute is not None else None,

Also applies to: 240-240
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 159 -
169, The NameError risk comes from params_cute being defined only inside the if
self._has_params block yet referenced later; fix by defining params_cute = None
before the if and only assigning it inside the block (where you call
from_dlpack) so later code can safely use the conditional expression
(params_cute.iterator if self._has_params else None); update references
involving self._has_params, _params_torch, and from_dlpack accordingly to rely
on the initialized params_cute variable.
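For reference, a minimal demonstration (hypothetical names mirroring the params_cute pattern) of how a conditional expression guards the name lookup, and why an explicit sentinel still makes the code sturdier:

```python
def guarded_lookup(has_params: bool):
    # params is only bound when has_params is True (names are illustrative).
    if has_params:
        params = {"iterator": object()}
    # The condition is evaluated first, so the unbound-name branch is
    # skipped when has_params is False -- but the guard quietly depends
    # on the flag and the binding staying in sync.
    return params["iterator"] if has_params else None

def guarded_lookup_with_sentinel(has_params: bool):
    params = None  # explicit sentinel: the guard no longer needs the flag
    if has_params:
        params = {"iterator": object()}
    return params["iterator"] if params is not None else None
```

Both functions behave identically today; the sentinel version simply removes the hidden coupling between the flag and the binding.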
Summary
Modular rewrite of the FMHA prefill kernel from #1549, refactored into composable building blocks with bug fixes and comprehensive tests.
Bug fixes vs #1549
- apply_mask
- incorrect get_trip_count and get_kv_start_block_idx for symmetric windows
- KV coordinate offset mismatch in softmax stage

Key files
- flashinfer/cute_dsl/attention/prefill.py
- flashinfer/cute_dsl/attention/roles/
- flashinfer/cute_dsl/attention/fusion/
- flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
- flashinfer/cute_dsl/attention/config.py
- flashinfer/cute_dsl/attention/pipeline_topology.py
- flashinfer/cute_dsl/attention/collective_builder.py
- tests/test_blackwell_fmha_attention.py
- benchmarks/bench_blackwell_attention_cutedsl.py

Test plan
cc: @yzh119
Summary by CodeRabbit
New Features
Tests