
[CuTe DSL] Add modular FMHA prefill attention kernel#2805

Open
pgera wants to merge 14 commits into flashinfer-ai:main from pgera:cutedsl-fmha-prefill

Conversation


@pgera pgera commented Mar 17, 2026

Summary

Modular rewrite of the FMHA prefill kernel from #1549, refactored into composable building blocks with bug fixes and comprehensive tests.

  • Architecture uses composable "roles" (Loader, MMA, Softmax, Correction, Epilogue) connected by declarative pipeline topologies, making it easy to swap warp schedules, fusion strategies, and masking
  • Supports bf16/fp16, head_dim 64/128, GQA/MHA, causal/sliding-window masks, custom logits transforms (e.g. sigmoid), output transforms, attention sink with custom M_D_update, and variable-length sequences

Bug fixes vs #1549

  • Causal mask boundary + PV accumulate: fixed off-by-one in causal mask and accumulation bug
  • Attention sink dtype mismatch: wrapper hardcoded fp16 for sink tensor regardless of input dtype
  • Attention sink M_D_update domain: corrected domain conversion and exp2 scaling in online softmax
  • Sliding window mask (3 issues): missing left-bound check in apply_mask, incorrect get_trip_count and get_kv_start_block_idx for symmetric windows, KV coordinate offset mismatch in softmax stage
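For reviewers unfamiliar with the sliding-window fix, a hypothetical scalar reference (not the kernel's tile-level apply_mask) illustrates the missing left-bound check: a right-only predicate keeps KV positions arbitrarily far to the left of the query.

```python
def sliding_window_keep(q: int, kv: int, window: int, seqlen_k: int) -> bool:
    """Reference predicate: (q, kv) survives a symmetric sliding-window mask
    only when kv is within `window` of q on BOTH sides and in-bounds."""
    in_bounds = 0 <= kv < seqlen_k       # seqlen_k bounds check
    in_window = abs(kv - q) <= window    # left AND right bound
    return in_bounds and in_window


def right_only_keep(q: int, kv: int, window: int, seqlen_k: int) -> bool:
    """Buggy variant with only the right bound: leaks far-left positions."""
    return 0 <= kv < seqlen_k and kv <= q + window
```

For example, with q=100 and window=4, the right-only predicate wrongly keeps kv=0, while the full predicate rejects it.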

Key files

  • flashinfer/cute_dsl/attention/prefill.py: Main FMHA prefill kernel
  • flashinfer/cute_dsl/attention/roles/: Warp role implementations (softmax, MMA, loader, correction, epilogue)
  • flashinfer/cute_dsl/attention/fusion/: Mask, logits transform, output transform
  • flashinfer/cute_dsl/attention/wrappers/batch_prefill.py: PyTorch wrapper API
  • flashinfer/cute_dsl/attention/config.py: Kernel configuration
  • flashinfer/cute_dsl/attention/pipeline_topology.py: Declarative pipeline graph
  • flashinfer/cute_dsl/attention/collective_builder.py: MMA atoms, TMA, SharedStorage factory
  • tests/test_blackwell_fmha_attention.py: Test suite (8 test cases)
  • benchmarks/bench_blackwell_attention_cutedsl.py: Benchmark script

Test plan

  • All 8 tests pass on SM100 (bf16/fp16, causal, sliding window, head_dim 64/128, GQA/MHA, sigmoid logits, attention sink, varlen)
  • Benchmark against existing FlashInfer attention kernels
  • Review pipeline topology and warp schedule for performance
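The tests compare kernel output against reference implementations. A minimal sketch of a per-row numerically stable softmax reference, with an optional sink logit that joins the denominator but contributes no output (names and exact semantics here are illustrative, not the test file's code):

```python
import math


def ref_attention_row(scores, scale, sink=None):
    """Reference softmax weights for one query row.

    `sink` is an optional extra logit included in the max and the
    denominator only, as in sink attention."""
    logits = [s * scale for s in scores]
    m = max(logits + ([sink] if sink is not None else []))
    exp = [math.exp(x - m) for x in logits]
    d = sum(exp) + (math.exp(sink - m) if sink is not None else 0.0)
    return [e / d for e in exp]
```

Without a sink the weights sum to 1; with a large sink the weights shrink toward 0 because the sink absorbs probability mass.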

cc: @yzh119

Summary by CodeRabbit

  • New Features

    • Added CuTe DSL-based fused multi-head attention kernels for Blackwell GPUs with support for batch prefill operations
    • Introduced attention variant customizations: standard softmax, sink attention, sigmoid transforms, ALiBi and RPE scoring, soft-capping
    • Added support for causal masking, sliding window masking, and variable-length sequences
    • Introduced comprehensive benchmarking tools for attention kernel performance profiling
  • Tests

    • Added extensive test suite validating attention operations across multiple configurations and masking strategies

pgera added 6 commits March 16, 2026 02:01
Add CuTe DSL-based attention implementation:
- flashinfer/cute_dsl/attention/ - Modular attention package with composable
  roles (loader, softmax, MMA, epilogue), fusion points (logits transform,
  mask, output transform), and schedulers
- flashinfer/cute_dsl/prefill.py - Batch prefill wrapper
- flashinfer/cute_dsl/mla.py - MLA decode wrapper
- flashinfer/cute_dsl/patch/pipeline.py - Pipeline patching utilities

Tests and benchmarks (named to avoid conflicts with existing cutlass tests):
- tests/test_blackwell_fmha_cutedsl.py - FMHA tests (prefill)
- tests/test_blackwell_fmha_attention.py - Modular attention package tests
- tests/test_blackwell_mla_attention.py - MLA attention tests
- tests/test_deepseek_mla_cutedsl.py - DeepSeek MLA tests
- benchmarks/bench_blackwell_attention_cutedsl.py - Attention benchmarks
- docs/cutedsl_fmha_architecture.md - Architecture documentation

Made-with: Cursor
- Delete flashinfer/cute_dsl/prefill.py and mla.py (replaced by
  the modular flashinfer/cute_dsl/attention/ package)
- Delete tests/test_blackwell_fmha_cutedsl.py and
  tests/test_deepseek_mla_cutedsl.py (replaced by
  test_blackwell_fmha_attention.py and test_blackwell_mla_attention.py)
- Revert benchmarks/bench_deepseek_mla.py to upstream version
- Split benchmarks into prefill and decode:
  bench_blackwell_attention_cutedsl.py (FMHA prefill)
  bench_blackwell_mla_cutedsl.py (MLA decode)

Made-with: Cursor
…rnels

Two kernel correctness bugs fixed:

1. PV1(end) accumulate flag: The final PV GEMM for stage 1 used hardcoded
   accumulate=True, causing stale TMEM data corruption when the KV loop
   didn't execute (kv_len <= tile_size with multi-Q-tile batches).
   Fix: use pv_whether_acc instead of True.

2. Causal mask trip count: get_masked_trip_count used ceil_div(M, N) which
   doesn't account for non-zero causal_offset shifting the diagonal across
   extra KV tiles. When kv_len != qo_len, some tiles needing masking were
   processed as unmasked, leaking unmasked scores into softmax.
   Fix: compute masked tile count from actual diagonal boundary positions.

Both fixes required threading seqlen_q through the mask functions and
passing causal_offset to apply_mask.

Test suite pruned to ~112 curated cases covering tile boundaries, GQA,
varlen, causal, output/logits transforms, and attention sink.

AI-assisted (Claude)

Made-with: Cursor
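The corrected causal trip count described above can be sketched as a scalar reference (a hypothetical standalone function; the kernel version operates on cute coordinates and uses cute.ceil_div):

```python
import math


def causal_trip_count(q_blk: int, tile_m: int, tile_n: int,
                      seqlen_k: int, seqlen_q: int) -> int:
    """Number of KV tiles a Q tile must visit under a causal mask,
    accounting for the diagonal shift when kv_len != qo_len."""
    causal_offset = seqlen_k - seqlen_q
    max_blocks_k = math.ceil(seqlen_k / tile_n)
    # The last Q row of this tile may attend up to kv = q + causal_offset,
    # which the buggy ceil_div(M, N) formulation did not account for.
    max_blocks_q = math.ceil(((q_blk + 1) * tile_m + causal_offset) / tile_n)
    return min(max_blocks_k, max_blocks_q)
```

With seqlen_k=256 and seqlen_q=128 (offset 128), the first 128-row Q tile must visit 2 KV tiles, whereas the offset-blind formula would visit only 1 and leak unmasked scores.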
…ate domain conversion

The plan() method created a dummy sink tensor with hardcoded float16 dtype
for JIT compilation regardless of input dtype. When bfloat16 inputs were
used at runtime, the compiled kernel misinterpreted bf16 bits as fp16,
garbling sink values (causal row-0 error: 1.75 -> 0.004).

Also fix the sink_M_D_update test helper to properly convert the sink value
from scaled-logit space to raw-logit space by dividing by scale, and tighten
the sink test tolerance from atol=2.0 to atol=0.01.

AI-assisted (Claude)

Made-with: Cursor
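The domain conversion behind this fix can be illustrated with a small sketch (sink_to_raw is a hypothetical helper name; the real conversion lives in the test's sink_M_D_update):

```python
import math

LOG2_E = math.log2(math.e)


def sink_to_raw(sink: float, scale: float) -> float:
    """Convert a sink value from scaled natural-log space to the kernel's
    raw exp2 domain: dividing by scale undoes the softmax scaling, and
    multiplying by log2(e) moves from base-e to base-2 logits."""
    return sink * LOG2_E / scale
```

Round trip: rescaling the raw value by `scale` and exponentiating base 2 recovers exp(sink), i.e. 2 ** (scale * sink_to_raw(s, scale)) == e ** s.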
…ve tests (AI-assisted)

Kernel fixes:
- Sliding window apply_mask: add missing left-bound check (|kv-q| > window)
  and seqlen_k bounds check
- Sliding window get_trip_count/get_kv_start_block_idx: compute correct
  symmetric window tile range instead of right-only approximation
- Softmax run(): add kv_start_offset to coordinate identity tensor so mask
  coordinates match actual KV positions loaded by the TMA loader

Test fixes:
- sink_M_D_update: add * scale to exp2 rescale terms for correctness
  (m is in RAW domain, exp2 needs domain conversion via * scale)
- Sink test: use SM_SCALE=1/sqrt(head_dim) instead of 1.0, which made the
  sink contribution negligible (~0) and the test vacuous

New test coverage:
- float16 dtype (3 shapes x 2 causal)
- Sliding window mask (4 window/shape combos)
- head_dim=64 (3 shapes x 2 causal)
- Variable-length + sigmoid logits transform (2 indptr patterns)
- Variable-length + attention sink (2 indptr patterns)
- Attention sink with MHA / num_kv_heads=32 (2 shapes x 2 causal)

All 118 tests pass, 18 skipped (qo>kv+causal), ~10 min runtime.

Made-with: Cursor
Strip out MLA decode kernel, config, warp schedule, roles, scheduler,
wrapper, benchmark, test, and architecture doc to keep this PR focused
on FMHA prefill only. Clean up MLA references in shared modules.

AI-assisted

Made-with: Cursor

coderabbitai bot commented Mar 17, 2026

📝 Walkthrough

Walkthrough

This PR introduces a comprehensive CuTe DSL-based fused multi-head attention (FMHA) implementation for Blackwell GPUs. It adds modular attention kernel infrastructure with configurable variants, masking, pipeline topology, and a PyTorch wrapper for batch prefill operations, complemented by benchmarking and extensive testing.

Changes

  • Benchmark Infrastructure (benchmarks/bench_blackwell_attention_cutedsl.py): Added benchmark script with bench_fmha_blackwell() and bench_fmha_cutedsl() functions to measure GPU runtime, throughput (TFLOPs/s), and bandwidth (GB/s) for batched prefill attention across varying batch sizes and sequence lengths.
  • Attention Package Core (flashinfer/cute_dsl/attention/__init__.py, config.py, tmem_layout.py, warp_schedule.py): Introduced package-level re-exports, configuration dataclasses (AttentionConfig, HeadMapping, TileBounds), and foundational infrastructure (TmemLayout for TMEM buffer allocation, WarpSchedule for warp role assignments).
  • Fusion Variants & Masking (flashinfer/cute_dsl/attention/fusion/...): Added modular attention variant system (AttentionVariant base class with concrete implementations: StandardAttention, AttentionWithSink, SigmoidAttention, ALiBiAttention, RPEAttention, SoftCappingAttention) and mask infrastructure with MaskType enum and JIT-based masking helpers (apply_mask, get_trip_count, etc.).
  • Pipeline & Mainloop Topology (flashinfer/cute_dsl/attention/pipeline_topology.py, mainloop_spec.py): Defined declarative pipeline-graph representation (PipelineType, PipelineEdge, PipelineTopology) and mainloop composition (MainloopSpec) for wiring kernel components with configurable stage counts and pipeline edges.
  • Kernel Roles / Execution Stages (flashinfer/cute_dsl/attention/roles/...): Implemented specialized orchestration classes for each computation stage: LoaderRole (TMA-based data loading), MmaRole (QK/PV GEMMs with TMEM management), SoftmaxRole (online softmax with masking/fusion support), CorrectionRole (post-softmax correction), EpilogueRole (output writing), plus shared softmax math utilities.
  • Static Tile Scheduler (flashinfer/cute_dsl/attention/scheduler/...): Added persistent/static FMHA tile scheduler (FmhaStaticTileScheduler, FmhaStaticTileSchedulerParams) with MLIR serialization support for mapping CTAs to problem tiles.
  • Kernel Implementation (flashinfer/cute_dsl/attention/collective_builder.py, prefill.py): Introduced build_fmha_launch_params() for deriving CUTLASS/CuTe launch parameters (TMA atoms, cluster shapes, shared storage) and BlackwellFusedMultiHeadAttentionForward class orchestrating all kernel roles, warp specialization, and barrier synchronization.
  • PyTorch Wrapper (flashinfer/cute_dsl/attention/wrappers/batch_prefill.py): Added BatchPrefillCuteDSLWrapper providing PyTorch integration: plan() validates inputs and compiles the kernel with attention config/fusion, run() launches the compiled kernel with DLPack tensor conversion; includes helper functions qkv_torch_2_cute() and create_and_pad_tensor() for tensor padding/layout.
  • Test Suite (tests/test_blackwell_fmha_attention.py): Comprehensive pytest module with reference attention implementations (softmax, causal, sink, sigmoid, ALiBi, RPE, sliding-window variants) and parametrized tests validating BatchPrefillCuteDSLWrapper across diverse configurations (batch sizes, sequence lengths, head dimensions, dtypes, variants).

Suggested labels

ready, run-ci

Suggested reviewers

  • yzh119
  • aleozlx
  • cyx-6
  • jimmyzho
  • jiahanc
  • nvmbreughe
  • kahyunnam
  • bkryu

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 A kernel so grand, CuTe DSL in hand,
With softmax and sinks, and attention so planned,
Pipeline roles dance in a Blackwell parade,
From loader to epilogue, fusion-made trade!
Attention perfected with warps in the sky. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 42.55%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title '[CuTe DSL] Add modular FMHA prefill attention kernel' accurately describes the main addition: a new modular prefill attention kernel composed of reusable building blocks in the CuTe DSL framework.
  • Description check ✅ Passed: The PR description covers objectives, architecture, features, bug fixes, key files, and test plan, addressing all major template sections. While pre-commit checks and some test confirmations are marked incomplete, the core description content is comprehensive and well-structured.


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a refactored and improved FMHA prefill kernel using a modular design. The new implementation addresses several bugs present in the previous version and offers a more flexible and testable architecture. By composing roles and using declarative pipeline topologies, the kernel supports a wide range of configurations and customizations, making it suitable for various attention mechanisms.

Highlights

  • Modular FMHA Prefill Kernel: This PR introduces a modular rewrite of the FMHA prefill kernel, enhancing composability and flexibility.
  • Bug Fixes: Addresses several bugs in the previous implementation, including causal mask boundary issues, dtype mismatches, and sliding window mask errors.
  • Comprehensive Testing: Includes a comprehensive test suite with 8 test cases covering various configurations (bf16/fp16, causal, sliding window, head_dim 64/128, GQA/MHA, sigmoid logits, attention sink, varlen).
  • Composable Architecture: The architecture is designed with composable "roles" connected by declarative pipeline topologies, facilitating easy swapping of warp schedules and fusion strategies.
  • Extensive Support: Supports bf16/fp16, head_dim 64/128, GQA/MHA, causal/sliding-window masks, custom logits transforms, output transforms, attention sink with custom M_D_update, and variable-length sequences.


Activity
  • The PR introduces a new modular FMHA prefill kernel.
  • It includes bug fixes for causal mask boundary, attention sink dtype mismatch, attention sink M_D_update domain, and sliding window mask issues.
  • The PR adds a comprehensive test suite with 8 test cases.
  • It refactors the kernel into composable building blocks with declarative pipeline topologies.
  • The PR supports various configurations such as bf16/fp16, head_dim 64/128, GQA/MHA, causal/sliding-window masks, custom logits transforms, output transforms, attention sink, and variable-length sequences.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and well-structured modular rewrite of the FMHA prefill kernel using CuTe DSL. The new architecture, based on composable roles and declarative pipeline topologies, is a major improvement for maintainability and extensibility. The comprehensive test suite covers a wide range of configurations, ensuring the robustness of the new implementation. My review includes a few suggestions for code cleanup, improving maintainability by adding documentation, and addressing potential issues like unused code and a missing JIT decorator.

Comment on lines +25 to +52
def get_trip_count(
    mask_type: MaskType,
    window_left: int,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_k: Int32,
    seqlen_q: Int32 = 0,
) -> Int32:
    """Number of KV tile blocks to process for this Q tile."""
    result = 0
    if mask_type == MaskType.NO_MASK or mask_type == MaskType.RESIDUAL_MASK:
        result = cute.ceil_div(seqlen_k, tile_shape[1])
    elif mask_type == MaskType.CAUSAL_MASK:
        max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
        causal_offset = seqlen_k - seqlen_q
        max_blocks_q = cute.ceil_div(
            (blk_coord[0] + 1) * tile_shape[0] + causal_offset, tile_shape[1]
        )
        result = cutlass.min(max_blocks_k, max_blocks_q)
    elif mask_type == MaskType.SLIDING_WINDOW_MASK:
        first_q = blk_coord[0] * tile_shape[0]
        last_q = (blk_coord[0] + 1) * tile_shape[0] - 1
        min_kv = cutlass.max(0, first_q - window_left)
        max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
        start_block = min_kv // tile_shape[1]
        end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
        result = end_block - start_block
    return result

medium

The function get_trip_count is called from other JIT-compiled functions (e.g., in loader_tma.py, mma.py, correction.py, and indirectly from softmax.py) but it is not decorated with @cute.jit. While this might be handled by the compiler through inlining, it's better to be explicit to ensure correctness and improve maintainability. Please add the @cute.jit decorator to this function.

@cute.jit
def get_trip_count(
    mask_type: MaskType,
    window_left: int,
    blk_coord: cute.Coord,
    tile_shape: cute.Shape,
    seqlen_k: Int32,
    seqlen_q: Int32 = 0,
) -> Int32:
    """Number of KV tile blocks to process for this Q tile."""
    result = 0
    if mask_type == MaskType.NO_MASK or mask_type == MaskType.RESIDUAL_MASK:
        result = cute.ceil_div(seqlen_k, tile_shape[1])
    elif mask_type == MaskType.CAUSAL_MASK:
        max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
        causal_offset = seqlen_k - seqlen_q
        max_blocks_q = cute.ceil_div(
            (blk_coord[0] + 1) * tile_shape[0] + causal_offset, tile_shape[1]
        )
        result = cutlass.min(max_blocks_k, max_blocks_q)
    elif mask_type == MaskType.SLIDING_WINDOW_MASK:
        first_q = blk_coord[0] * tile_shape[0]
        last_q = (blk_coord[0] + 1) * tile_shape[0] - 1
        min_kv = cutlass.max(0, first_q - window_left)
        max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
        start_block = min_kv // tile_shape[1]
        end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
        result = end_block - start_block
    return result

Comment on lines +1 to +101
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Shared TMEM utilities for compute roles.

Provides tmem_load_partition() — partitions TMEM output accumulator for
load/store by the rescale and epilogue roles.
"""

from types import SimpleNamespace

import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.tcgen05 as tcgen05


@cute.jit
def tmem_load_partition(
    tmem_ptr: cutlass.Int32,
    tmem_o_offset: int,
    acc_dtype: cutlass.Constexpr,
    mma_pv_tiler: cutlass.Constexpr,
    cluster_shape_mnk: cutlass.Constexpr,
    warps_in_n: int,
    num_compute_warps: int,
    threads_per_warp: int,
    common_params: SimpleNamespace,
    tiled_mma_pv: cute.TiledMma,
    iter_n: int,
) -> tuple[
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
]:
    tOtO_shape = tiled_mma_pv.partition_shape_C(
        cute.select(mma_pv_tiler, mode=[0, 1])
    )
    tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape)
    tOtO_layout = cute.append(
        tOtO.layout,
        cute.make_layout(
            common_params.L // mma_pv_tiler[1],
            stride=mma_pv_tiler[1] // warps_in_n,
        ),
    )
    tOtO = cute.make_tensor(tmem_ptr + tmem_o_offset, tOtO_layout)
    tOtO = tOtO[None, None, None, iter_n]

    tAcc = tOtO[(None, None), 0, 0]

    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), acc_dtype
    )
    tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc)
    tmem_load_thr_copy = tmem_load_tiled_copy.get_slice(
        common_params.tidx % (num_compute_warps * threads_per_warp)
    )

    cta_pv_tiler = (
        mma_pv_tiler[0] // cluster_shape_mnk[0],
        mma_pv_tiler[1],
        mma_pv_tiler[2],
    )
    cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1])

    gO = None
    if cutlass.const_expr(common_params.mAccO is not None):
        gO = cute.local_tile(
            common_params.mAccO[None, common_params.blk_coord[3], None, None],
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
        cO = cute.local_tile(
            cute.make_identity_tensor(
                common_params.mAccO[
                    None, common_params.blk_coord[3], None, None
                ].shape
            ),
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
    else:
        gO = cute.local_tile(
            common_params.mO,
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
        cO = cute.local_tile(
            cute.make_identity_tensor(common_params.mO.shape),
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )

    tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc)
    tTR_gO = tmem_load_thr_copy.partition_D(gO)
    tTR_cO = tmem_load_thr_copy.partition_D(cO)
    tTR_rAcc = cute.make_fragment_like(tTR_gO, acc_dtype)
    return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc

medium

This file tmem_utils.py and the function tmem_load_partition within it do not appear to be used anywhere in the codebase. Unused code can lead to confusion and maintenance overhead. Please either integrate it into the kernel or remove it if it's obsolete.

Comment on lines +171 to +177
if s_k.shape[0] > 1:
    for i in range(len(s_k)):
        if s_k[i] % self._mma_tiler_mn[1] != 0:
            self._mask_type = MaskType.RESIDUAL_MASK
else:
    if s_k % self._mma_tiler_mn[1] != 0:
        self._mask_type = MaskType.RESIDUAL_MASK

medium

The logic to determine if RESIDUAL_MASK is needed can be simplified. The current implementation iterates over the s_k tensor and has a branch for s_k.shape[0] > 1 which is always taken since s_k is derived from kv_indptr and will be a 1D tensor. You can use torch.any for a more concise and efficient check.

            if torch.any(s_k % self._mma_tiler_mn[1] != 0):
                self._mask_type = MaskType.RESIDUAL_MASK

@@ -0,0 +1,419 @@
from typing import Optional, Type, Tuple

medium

This file provides a custom implementation of pipeline participants, which appears to be a patch on top of cutlass.pipeline. However, it's missing a file-level docstring explaining why this custom implementation is necessary and what it changes compared to the original. Adding a docstring would greatly improve maintainability and make it easier for other developers to understand the purpose of this module.

def sink_M_D_update(params, kv_tile_idx, qo_head_idx, m, d, scale):
    # m is in the RAW (unscaled) domain; convert sink from scaled-logit to RAW
    log2_e = math.log2(math.exp(1.0))
    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf

medium

The condition qo_head_idx < NUM_QO_HEADS is redundant because qo_head_idx is a grid coordinate over the heads dimension, which is sized to NUM_QO_HEADS. Therefore, qo_head_idx will always be less than NUM_QO_HEADS. You can simplify the expression.

Suggested change
sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
sink_raw = params.sink[qo_head_idx] * log2_e / scale if kv_tile_idx == 0 else -math.inf

@cute.jit
def sink_M_D_update(params, kv_tile_idx, qo_head_idx, m, d, scale):
    log2_e = math.log2(math.exp(1.0))
    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf

medium

The condition qo_head_idx < NUM_QO_HEADS is redundant because qo_head_idx is a grid coordinate over the heads dimension, which is sized to NUM_QO_HEADS. Therefore, qo_head_idx will always be less than NUM_QO_HEADS. You can simplify the expression.

Suggested change
sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
sink_raw = params.sink[qo_head_idx] * log2_e / scale if kv_tile_idx == 0 else -math.inf

pgera added 8 commits March 18, 2026 15:28
The PipelineProducer/PipelineConsumer wrappers are now available in
cutlass.pipeline (nvidia-cutlass-dsl >= 4.3). Use them directly
instead of maintaining a local copy. Pipeline creation uses
defer_sync=True since the kernel handles barrier init/sync separately.

Verified: no perf regression (< 1% noise), all 118 tests pass.

AI-assisted

Made-with: Cursor
…ed tmem_utils.py

- Add missing @cute.jit decorator to get_trip_count for consistency
  with all other functions in mask.py
- Remove tmem_utils.py which was MLA-specific dead code after MLA removal

AI-assisted

Made-with: Cursor
…AI-assisted)

- Use torch.any() for concise RESIDUAL_MASK determination in batch_prefill.py
- Remove always-true qo_head_idx < NUM_QO_HEADS condition in sink_M_D_update tests

Made-with: Cursor
…d hook (AI-assisted)

Replace the dual sink_iter + variant_data_iter kernel parameters with a
single params_iter path. Variants now expose runtime tensor data via an
extra_params property; the kernel binds it to self.params before calling
any JIT method.

Key changes:
- AttentionVariant: remove use_attention_sink/variant_data_tensor, add
  extra_params property and score_mod hook
- AttentionWithSink: take sink tensor in constructor instead of run()
- AttentionFusion: replace use_attention_sink + has_variant_data with
  has_params + params_shape + params_strides
- prefill.py/softmax.py: merge sink + variant_data into single params
- batch_prefill.py: remove sink from run(), add contiguity assertion
- Fix CuTe column-major vs PyTorch row-major layout mismatch for N-D
  params by deriving explicit strides from the PyTorch tensor
- Add ALiBiAttention (1-D params), RPEAttention (2-D params),
  SoftCappingAttention (compile-time only) variant classes
- Add ALiBi and RPE tests with tight tolerances

Made-with: Cursor
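The single-params pattern described in this commit can be sketched as follows (class and property names follow the commit message; the real signatures in the package may differ, so treat this as an illustrative shape, not the actual API):

```python
class AttentionVariant:
    """Base class: variants may expose runtime tensor data via extra_params."""

    @property
    def extra_params(self):
        return None  # compile-time-only variants carry no runtime tensor


class SoftCappingAttention(AttentionVariant):
    """Compile-time-only variant: the cap is baked into the kernel."""

    def __init__(self, cap: float):
        self.cap = cap


class AttentionWithSink(AttentionVariant):
    """Sink tensor is taken in the constructor instead of run()."""

    def __init__(self, sink_tensor):
        self._sink = sink_tensor

    @property
    def extra_params(self):
        return self._sink
```

The kernel can then bind `variant.extra_params` to a single params slot before invoking any JIT method, instead of threading separate sink and variant-data iterators.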
…assisted)

- Add can_implement() to AttentionConfig for early validation of MMA tile
  size, head_dim divisibility, and GQA repeat count
- Add SMEM capacity check in prefill kernel to catch head_dim overruns
  with a clear error instead of a cryptic CUDA runtime error
- Add _validate_run_inputs() to BatchPrefillCuteDSLWrapper for runtime
  dtype/device/shape consistency checks
- Change MainloopSpec.resolve() to return a new object instead of
  mutating in place, avoiding subtle state bugs
- Clarify docstrings for decode-reserved fields and pipeline ordering

Made-with: Cursor
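A hypothetical sketch of the early-validation idea (the concrete checks, thresholds, and signature here are illustrative, not the actual AttentionConfig.can_implement):

```python
def can_implement(head_dim: int, num_qo_heads: int, num_kv_heads: int,
                  tile_m: int = 128) -> tuple[bool, str]:
    """Return (ok, reason) before compiling, so unsupported configs fail
    with a clear message instead of a cryptic CUDA runtime error."""
    if head_dim not in (64, 128):
        return False, f"unsupported head_dim {head_dim}"
    if num_qo_heads % num_kv_heads != 0:
        return False, "GQA requires num_qo_heads divisible by num_kv_heads"
    if tile_m % 64 != 0:
        return False, f"MMA tile M {tile_m} must be a multiple of 64"
    return True, ""
```

A wrapper would call this in plan() and raise ValueError with the reason string when the first element is False.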
@pgera pgera marked this pull request as ready for review March 20, 2026 20:03

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

🧹 Nitpick comments (13)
flashinfer/cute_dsl/attention/tmem_layout.py (1)

35-49: Consider extracting SM100_TMEM_CAPACITY_COLUMNS as a module-level constant.

The SM100 TMEM capacity is a hardware characteristic that may be referenced elsewhere. Extracting it improves discoverability and avoids magic numbers.

Proposed refactor
+SM100_TMEM_CAPACITY_COLUMNS = 512
+
+
+SM100_TMEM_CAPACITY_COLUMNS = 512
+
+
 @dataclass(frozen=True)
 class TmemLayout:
     ...

     @staticmethod
     def from_config(config: AttentionConfig) -> TmemLayout:
         tile_m = config.mma_tiler[0]
-        SM100_TMEM_CAPACITY_COLUMNS = 512
         return TmemLayout(
             ...
             alloc_cols=SM100_TMEM_CAPACITY_COLUMNS,
         )
flashinfer/cute_dsl/attention/scheduler/persistent.py (2)

38-45: Add strict=True to zip() for safer MLIR value reconstruction.

In __new_from_mlir_values__, the zip() call iterates over [self.is_persistent, self.problem_shape_mbh] and self._values_pos. If these lists have mismatched lengths (e.g., due to a maintenance error), zip() will silently truncate, potentially causing subtle bugs during MLIR reconstruction.

Also, the ip parameter is not forwarded to the new FmhaStaticTileSchedulerParams instance on line 45.

Proposed fix
     def __new_from_mlir_values__(self, values):
         obj_list = []
         for obj, n_items in zip(
-            [self.is_persistent, self.problem_shape_mbh], self._values_pos
+            [self.is_persistent, self.problem_shape_mbh], self._values_pos, strict=True
         ):
             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
             values = values[n_items:]
-        return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
+        return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc, ip=self._ip)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/scheduler/persistent.py` around lines 38 - 45,
In __new_from_mlir_values__ update the zip over [self.is_persistent,
self.problem_shape_mbh] and self._values_pos to use zip(..., strict=True) to
fail loudly on length mismatches, and when returning the
FmhaStaticTileSchedulerParams instance forward the current object's ip parameter
(pass loc=self._loc, ip=self.ip) so the new instance receives ip as well; this
touches the __new_from_mlir_values__ method, the attributes self.is_persistent,
self.problem_shape_mbh, self._values_pos, and the FmhaStaticTileSchedulerParams
constructor call.

148-158: Hardcoded MLIR value count is fragile.

The assertion assert len(values) == 10 couples the implementation to a specific MLIR representation. If any constituent object's MLIR value count changes, this will fail without a clear message.

Consider deriving the expected count dynamically or providing a descriptive error message.

Proposed improvement
     def __new_from_mlir_values__(self, values):
-        assert len(values) == 10
+        expected = 3 + 1 + 3 + 3  # params(3) + work_idx(1) + blk_coord(3) + grid_shape(3)
+        assert len(values) == expected, f"Expected {expected} MLIR values, got {len(values)}"
         new_params = cutlass.new_from_mlir_values(self._params, values[0:3])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/scheduler/persistent.py` around lines 148 -
158, The hardcoded assertion in __new_from_mlir_values__ (assert len(values) ==
10) is fragile; change it to compute the expected MLIR value count by summing
the MLIR-value counts of the constituent objects (self._params,
self._current_work_linear_idx, self._blk_coord, self._grid_shape) using whatever
helper/attribute your cutlass layer exposes (e.g., a mlir value count helper or
by querying each object's MLIR representation), then compare len(values) to that
computed total and raise a ValueError with a descriptive message if mismatched;
update the slicing logic that builds new_params, new_current_work_linear_idx,
new_blk_coord, and new_grid_shape to use those computed per-object counts
instead of fixed indices so FmhaStaticTileScheduler construction remains
correct.
flashinfer/cute_dsl/attention/collective_builder.py (1)

163-186: Consider using a typed dataclass instead of SimpleNamespace for better IDE support.

The returned SimpleNamespace contains 20+ fields. A typed dataclass or NamedTuple would provide autocompletion and type checking for consumers.
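A minimal sketch of the suggested replacement, showing only a subset of the fields and using `Any` since the concrete CuTe types aren't shown here (class and field names follow the review's suggestion, not existing code):

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class AttentionCollective:
    """Typed container replacing the anonymous SimpleNamespace."""

    qk_tiled_mma: Any
    pv_tiled_mma: Any
    tma_atom_q: Any
    tma_tensor_q: Any
    epi_tile: Any
    # ... remaining fields declared the same way


collective = AttentionCollective(
    qk_tiled_mma=None,
    pv_tiled_mma=None,
    tma_atom_q=None,
    tma_tensor_q=None,
    epi_tile=(128, 64),
)
assert collective.epi_tile == (128, 64)  # attribute access is now typed
```

Consumers keep the same attribute-access syntax, but IDEs and type checkers can now see the field set.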

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 163 - 186,
Replace the anonymous SimpleNamespace return with a typed dataclass (e.g.,
define `@dataclass` class AttentionCollective or AttentionCollectiveConfig) that
declares typed fields for each symbol currently passed (qk_tiled_mma,
pv_tiled_mma, tma_atom_q, tma_tensor_q, tma_atom_k, tma_tensor_k, tma_atom_v,
tma_tensor_v, tma_atom_o, tma_tensor_o, q_smem_layout_staged,
k_smem_layout_staged, p_tmem_layout_staged, v_smem_layout_staged,
o_smem_layout_staged, SharedStorage, tma_copy_q_bytes, tma_copy_kv_bytes,
cluster_shape_mnk, cluster_layout_vmnk, epi_tile, o_layout), add appropriate
type hints (use typing.Any or more specific types if known), import dataclasses
and typing, instantiate and return that dataclass instead of SimpleNamespace,
and update any consumers to accept the new dataclass type for improved IDE
autocompletion and type checking.
benchmarks/bench_blackwell_attention_cutedsl.py (1)

7-8: Use the public flashinfer.testing benchmark helper.

This benchmark already relies on the standard timing helper, but it pulls it from flashinfer.testing.utils, which couples the script to a private module path.

♻️ Suggested change
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing import bench_gpu_time

Based on learnings: use flashinfer.testing.bench_gpu_time() for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_attention_cutedsl.py` around lines 7 - 8, The
benchmark imports bench_gpu_time from a private path (flashinfer.testing.utils);
update the import to use the public helper by replacing references to
flashinfer.testing.utils with the public module flashinfer.testing and import
bench_gpu_time from flashinfer.testing (i.e., use
flashinfer.testing.bench_gpu_time) so the benchmark relies on the supported
public API rather than a private module.
tests/test_blackwell_fmha_attention.py (1)

1-13: Please move this suite under a feature-specific tests subdirectory.

This is kernel-specific CuTe DSL attention coverage, but the new module sits at tests/ root. Putting it under a matching subdirectory keeps the test surface organized with the rest of the kernel-category suites.

As per coding guidelines tests/**/*.py: Prefix test functions with test_ and structure tests by feature in tests/ subdirectories matching kernel categories.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_blackwell_fmha_attention.py` around lines 1 - 13, The test module
test_blackwell_fmha_attention.py is at the tests/ root but belongs in the
attention-specific kernel tests; move this suite into a feature-specific tests
subdirectory matching the kernel category (e.g., an attention/ or
blackwell_fmha/ tests folder), update any relative imports inside the module to
the new location, and ensure all test callables in the file are properly
prefixed with test_ so pytest discovers them (check function names and any
parametrized fixtures used by functions in this module).
flashinfer/cute_dsl/attention/wrappers/batch_prefill.py (3)

393-396: Add strict=True to zip() for early shape-mismatch detection.

Using strict=True catches mismatched lengths between padding and shape_ early, improving debuggability.

Suggested fix
-        slices = tuple(slice(s, e) for s, e in zip(padding, shape_))
+        slices = tuple(slice(s, e) for s, e in zip(padding, shape_, strict=True))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 393 -
396, The slice construction using zip(padding, shape_) can silently ignore
length mismatches; update the tuple comprehension that defines slices (used to
create torch_tensor from torch_tensor_full and assigned to torch_tensor) to call
zip with strict=True (i.e., zip(padding, shape_, strict=True)) so any mismatch
between padding and shape_ raises immediately and makes debugging easier.

129-157: Prefix unused unpacked variables with underscore.

The variables q_ref, q_torch, k_ref, k_torch, v_ref, v_torch, and o_torch from create_and_pad_tensor() are intentionally unused (they're dummy tensors for CuTe JIT tracing). Prefix them with _ to indicate intent and silence linter warnings.

Suggested fix
-        q_ref, q_cute, q_torch = create_and_pad_tensor(
+        _q_ref, q_cute, _q_torch = create_and_pad_tensor(
             qo_shape,
             ...
         )
-        k_ref, k_cute, k_torch = create_and_pad_tensor(
+        _k_ref, k_cute, _k_torch = create_and_pad_tensor(
             kv_shape,
             ...
         )
-        v_ref, v_cute, v_torch = create_and_pad_tensor(
+        _v_ref, v_cute, _v_torch = create_and_pad_tensor(
             kv_shape,
             ...
         )

-        _, o_cute, o_torch = create_and_pad_tensor(
+        _, o_cute, _o_torch = create_and_pad_tensor(
             qo_shape,
             ...
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 129 -
157, The unpacked dummy tensors returned by create_and_pad_tensor (q_ref,
q_torch, k_ref, k_torch, v_ref, v_torch, o_torch) are unused and should be
prefixed with an underscore to indicate intentional unused variables and silence
linters; update the unpacking lines where create_and_pad_tensor is called (for
q_, k_, v_, and o_) to rename those specific variables to _q_ref/_q_torch,
_k_ref/_k_torch, _v_ref/_v_torch, and _o_torch (or similar underscore-prefixed
names) while keeping the used names q_cute/k_cute/v_cute/o_cute unchanged.

318-319: Minor: device=q.device is redundant with torch.empty_like.

torch.empty_like(q, ...) already inherits q's device by default.

Suggested fix
-            out = torch.empty_like(q, device=q.device)
+            out = torch.empty_like(q)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 318 -
319, In batch_prefill.py replace the redundant explicit device argument when
creating the empty tensor so that out is created with torch.empty_like(q)
instead of torch.empty_like(q, device=q.device); locate the assignment that sets
out when out is None (the one referencing variables out and q) and remove the
device=q.device parameter to rely on torch.empty_like inheriting q's device.
flashinfer/cute_dsl/attention/roles/softmax.py (1)

336-344: Redundant thread_idx computation.

thread_idx is computed identically at lines 337-344 and again at lines 366-373. The second computation overwrites the first with the same value.

Remove duplicate computation
         thread_idx = tidx % (
             self.threads_per_warp
             * (
                 len(self.softmax0_warp_ids)
                 if stage == 0
                 else len(self.softmax1_warp_ids)
             )
         )
         ...
         tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi)
-        thread_idx = tidx % (
-            self.threads_per_warp
-            * (
-                len(self.softmax0_warp_ids)
-                if stage == 0
-                else len(self.softmax1_warp_ids)
-            )
-        )
         thr_tmem_load = tiled_tmem_load.get_slice(thread_idx)

Also applies to: 366-373

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/roles/softmax.py` around lines 336 - 344, The
duplicated computation of thread_idx (calling cute.arch.thread_idx(), taking
tidx and computing tidx % (self.threads_per_warp * (len(self.softmax0_warp_ids)
if stage == 0 else len(self.softmax1_warp_ids)))) appears twice; remove the
redundant second block (the one at lines 366-373) so thread_idx remains computed
once and subsequent code uses the already-computed thread_idx from the first
occurrence; ensure any references after the removed block still rely on the
existing thread_idx variable and that no logic dependent on re-calling
cute.arch.thread_idx() is lost.
flashinfer/cute_dsl/attention/prefill.py (3)

155-156: Prefix unused s_k with underscore.

s_k is unpacked but never used. Prefix with _ to indicate intent.

-        b, s_q, s_k, h_q, h_k, d = problem_size
+        b, s_q, _s_k, h_q, h_k, d = problem_size
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/prefill.py` around lines 155 - 156, The tuple
unpacking of problem_size currently binds an unused variable s_k; change the
unpacking to use _s_k (or simply _ ) instead of s_k to signal it's intentionally
unused (e.g., replace "b, s_q, s_k, h_q, h_k, d = problem_size" with an
unpacking that prefixes s_k with an underscore) in the prefill logic where
variables b, s_q, h_q, h_k, d are used and h_r is computed from h_q and h_k.

45-51: Overly broad warning suppression may hide legitimate issues.

Suppressing all UserWarning messages (line 51) could mask important warnings from other parts of the codebase or dependencies. Consider scoping the suppression more narrowly, or applying it only within the specific context where the unrolling warning occurs.

Alternative: use a context manager at call sites
# Remove the global filter at module level
# warnings.filterwarnings("ignore", category=UserWarning)

# Instead, wrap specific calls that generate the warning:
import contextlib

`@contextlib.contextmanager`
def suppress_loop_unroll_warning():
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="This loop is no longer unrolled and may cause performance regression",
        )
        yield
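A standalone demonstration that the scoped filter silences only the targeted message while other UserWarnings stay visible (the message text is copied from the warning above; the call site is hypothetical):

```python
import contextlib
import warnings


@contextlib.contextmanager
def suppress_loop_unroll_warning():
    # catch_warnings() saves and restores the filter state, so the
    # ignore filter below only applies inside this block.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="This loop is no longer unrolled and may cause performance regression",
        )
        yield


# The targeted message is silenced...
with suppress_loop_unroll_warning():
    warnings.warn(
        "This loop is no longer unrolled and may cause performance regression",
        UserWarning,
    )

# ...but unrelated UserWarnings still propagate.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with suppress_loop_unroll_warning():
        warnings.warn("some other warning", UserWarning)
assert len(caught) == 1
```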
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/prefill.py` around lines 45 - 51, The module
currently suppresses all UserWarning globally by calling
warnings.filterwarnings("ignore", category=UserWarning); remove that broad
module-level filter and instead scope suppression to only the specific unroll
warning by introducing a context manager (e.g., suppress_loop_unroll_warning
using warnings.catch_warnings and warnings.filterwarnings with message="This
loop is no longer unrolled and may cause performance regression") and use that
context manager at the specific call sites in prefill.py where the unrolling
warning is raised so other UserWarnings remain visible.

385-386: Prefix unused tidx with underscore.

tidx from thread_idx() is unpacked but unused in the kernel entry. The variable is only used by roles that call thread_idx() themselves.

-        tidx, _, _ = cute.arch.thread_idx()
+        _tidx, _, _ = cute.arch.thread_idx()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/prefill.py` around lines 385 - 386, The
unpacked thread index variable tidx from cute.arch.thread_idx() is unused in the
kernel entry; change its name to _tidx to mark it as intentionally unused (i.e.,
replace "tidx, _, _ = cute.arch.thread_idx()" with "_tidx, _, _ =
cute.arch.thread_idx()") so linters/readers know it's unused while keeping the
other unpacked values and the existing warp_idx assignment (warp_idx =
cute.arch.make_warp_uniform(cute.arch.warp_idx())) intact.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_blackwell_attention_cutedsl.py`:
- Around line 153-161: The script currently unconditionally runs an SM100-only
kernel in the __main__ block (calls to bench_fmha_cutedsl), which will
JIT/launch-fail on non-SM100 GPUs; add a GPU capability check before running the
default sweep: use torch.cuda.is_available() and
torch.cuda.get_device_capability() or
torch.cuda.get_device_properties(device).major/minor (or device name) to detect
whether the current GPU supports SM100, and if not, skip the default
bench_fmha_cutedsl(...) calls and exit or print a clear message; update the
__main__ section so the SM100-only sweep only runs when the capability check
passes.

In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 96-98: The p_tmem_layout_staged is being created with the wrong
dtype (q_dtype) causing a mismatch with pv_tiled_mma which was created for V;
update the call to sm100_utils.make_smem_layout_a in collective_builder so
p_tmem_layout_staged uses v_dtype instead of q_dtype (the call that takes
pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage should pass
v_dtype) to align the P buffer TMEM layout with the V buffer.

In `@flashinfer/cute_dsl/attention/fusion/mask.py`:
- Around line 45-53: Sliding-window masking currently centers the KV window on
raw Q indices (see MaskType.SLIDING_WINDOW_MASK handling using blk_coord,
tile_shape, window_left, seqlen_k) and ignores the Q/K length offset used by the
causal path; compute q_k_offset = seqlen_k - seqlen_q and add it to first_q and
last_q (or otherwise shift Q indices into KV space) before calculating min_kv,
max_kv, start_block, end_block, and any element masks; apply the same fix to the
other sliding-window blocks noted (around the other occurrences at the given
ranges) so all sliding-window computations use shifted Q indices into KV
coordinate space.

In `@flashinfer/cute_dsl/attention/fusion/variant.py`:
- Around line 551-554: SoftCappingAttention.score_mod calls non-existent
cute.arch.tanh; replace it with a local tanh approximation implemented using
available primitives (e.g., cute.arch.exp2 and cute.arch.rcp_approx) or a cheap
rational polynomial and call that from score_mod. Add a helper function (e.g.,
_tanh_approx(x)) in the same class or module and use it in
SoftCappingAttention.score_mod (referencing self.cap and self.rcp_cap as
before), implementing tanh(x) via exp2 by computing exp(-2*abs(x)) with
exp2(-2*abs(x)/ln2) plus sign handling or by a stable rational approximation
(polynomial numerator/denominator) and ensure the helper uses
cute.jit-compatible operations only.
- Around line 367-378: Update the class and relevant parameter docstrings to
state that sink values are expected in the logit domain (raw Q·K dot-product
units, unnormalized), not pre-scaled to log2; specifically mention this near the
documentation for the sink parameter(s) used by update_statistics and the
self.params/sink_raw conversion (which divides by scale/log2_e), and add a
cross-reference to sink_softmax in sink_attention_reference.py so callers know
sinks are concatenated to logits before any log2 scaling.

In `@flashinfer/cute_dsl/attention/pipeline_topology.py`:
- Around line 68-79: The Pipeline.dataclass field cluster_scale is ignored by
create_pipelines(), causing incorrect participant and barrier arrive counts;
either (preferred) honor it by multiplying the all-thread side's participant
counts when constructing producer/consumer groups and computing barrier arrive
counts for PipelineType values UMMA_ASYNC and ASYNC_UMMA (but leave TMA_UMMA
unchanged), i.e., when building groups from producer_warp_ids/consumer_warp_ids
in create_pipelines() multiply the thread counts by pipeline.cluster_scale and
use that scaled value when setting arrive counts for barriers/tx_count_key, or
fail fast by adding a check in create_pipelines() that raises a clear exception
if pipeline.cluster_scale != 1 so callers must handle scaling explicitly.

In `@flashinfer/cute_dsl/attention/roles/epilogue.py`:
- Around line 41-66: partition_output is incorrectly decorated with `@cute.jit`
while returning tensor objects (tOsO, tOgO) which violates the CuTe JIT
limitation; either remove the `@cute.jit` decorator from partition_output so it
runs as a normal Python method, or refactor it to avoid returning tensors by (a)
accepting preallocated output containers/handles and writing into them, or (b)
moving the cute.nvgpu.cpasync.tma_partition call out of the `@cute.jit` function
into a non-jit wrapper (e.g., create partition_output_nonjit that calls
cute.nvgpu.cpasync.tma_partition and returns tensors or change partition_output
to populate passed-in tensor references); update references to partition_output
accordingly so no `@cute.jit` function returns tensors (symbols: partition_output,
tOsO, tOgO, tma_partition, tma_atom_o).

In `@flashinfer/cute_dsl/attention/warp_schedule.py`:
- Around line 17-71: Add a fail-fast validation in WarpSchedule (implement in a
__post_init__ method) that verifies: 1) all_warp_ids (built from
softmax0_warp_ids, softmax1_warp_ids, correction_warp_ids, mma_warp_id,
load_warp_id, epilogue_warp_id, empty_warp_id) contain unique values and form a
contiguous range starting at 0 up to len(all_warp_ids)-1, and 2) the total
number of softmax warps (len(softmax0_warp_ids)+len(softmax1_warp_ids)) is
divisible by num_warps_per_warpgroup; on violation raise ValueError with a clear
message referencing the failing condition so consumers of num_warps,
threads_per_cta, and softmax_warpgroup_count cannot silently compute incorrect
sizes.

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py`:
- Around line 159-169: The NameError risk comes from params_cute being defined
only inside the if self._has_params block yet referenced later; fix by defining
params_cute = None before the if and only assigning it inside the block (where
you call from_dlpack) so later code can safely use the conditional expression
(params_cute.iterator if self._has_params else None); update references
involving self._has_params, _params_torch, and from_dlpack accordingly to rely
on the initialized params_cute variable.
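As a plain-Python sketch of the tanh-via-exp2 identity suggested for SoftCappingAttention above, using `2.0 ** x` as a stand-in for the `cute.arch.exp2` primitive (this is illustrative math, not the kernel code):

```python
import math

LOG2_E = math.log2(math.e)


def tanh_approx(x: float) -> float:
    # tanh(x) = sign(x) * (1 - e^{-2|x|}) / (1 + e^{-2|x|}),
    # where e^{-2|x|} is computed as 2^{-2|x| * log2(e)} so that only
    # an exp2-style primitive (plus a reciprocal/divide) is required.
    t = 2.0 ** (-2.0 * abs(x) * LOG2_E)
    y = (1.0 - t) / (1.0 + t)
    return y if x >= 0 else -y


assert abs(tanh_approx(1.0) - math.tanh(1.0)) < 1e-12
assert abs(tanh_approx(-3.0) - math.tanh(-3.0)) < 1e-12
```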


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f4d60aef-cbff-41af-8cdd-f4914f9854d7

📥 Commits

Reviewing files that changed from the base of the PR and between e4dc66f and 9f0ba5e.

📒 Files selected for processing (24)
  • benchmarks/bench_blackwell_attention_cutedsl.py
  • flashinfer/cute_dsl/attention/__init__.py
  • flashinfer/cute_dsl/attention/collective_builder.py
  • flashinfer/cute_dsl/attention/config.py
  • flashinfer/cute_dsl/attention/fusion/__init__.py
  • flashinfer/cute_dsl/attention/fusion/mask.py
  • flashinfer/cute_dsl/attention/fusion/variant.py
  • flashinfer/cute_dsl/attention/mainloop_spec.py
  • flashinfer/cute_dsl/attention/pipeline_topology.py
  • flashinfer/cute_dsl/attention/prefill.py
  • flashinfer/cute_dsl/attention/roles/__init__.py
  • flashinfer/cute_dsl/attention/roles/correction.py
  • flashinfer/cute_dsl/attention/roles/epilogue.py
  • flashinfer/cute_dsl/attention/roles/loader_tma.py
  • flashinfer/cute_dsl/attention/roles/mma.py
  • flashinfer/cute_dsl/attention/roles/softmax.py
  • flashinfer/cute_dsl/attention/roles/softmax_math.py
  • flashinfer/cute_dsl/attention/scheduler/__init__.py
  • flashinfer/cute_dsl/attention/scheduler/persistent.py
  • flashinfer/cute_dsl/attention/tmem_layout.py
  • flashinfer/cute_dsl/attention/warp_schedule.py
  • flashinfer/cute_dsl/attention/wrappers/__init__.py
  • flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
  • tests/test_blackwell_fmha_attention.py

Comment on lines +153 to +161
if __name__ == "__main__":
bench_fmha_cutedsl(128, 512, 32, 128, True, torch.bfloat16)
bench_fmha_cutedsl(64, 1024, 32, 128, True, torch.bfloat16)
bench_fmha_cutedsl(32, 2048, 32, 128, True, torch.bfloat16)
bench_fmha_cutedsl(16, 4096, 32, 128, True, torch.bfloat16)
bench_fmha_cutedsl(8, 8192, 32, 128, True, torch.bfloat16)
bench_fmha_cutedsl(4, 16384, 32, 128, True, torch.bfloat16)
bench_fmha_cutedsl(2, 32768, 32, 128, True, torch.bfloat16)
bench_fmha_cutedsl(1, 65536, 32, 128, True, torch.bfloat16)

⚠️ Potential issue | 🟡 Minor

Skip the default sweep on unsupported GPUs.

The __main__ path unconditionally launches an SM100-only kernel. Running this script on another CUDA box will fail in JIT/launch instead of exiting cleanly with a clear message.
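One possible shape for that gate, sketched with the capability check factored out so it is testable without a GPU (function names are illustrative, not existing helpers):

```python
def supports_sm100(major: int, minor: int) -> bool:
    # SM100 corresponds to CUDA compute capability 10.x (Blackwell).
    return major == 10


def main() -> None:
    import torch

    if not torch.cuda.is_available():
        raise SystemExit("No CUDA device available; skipping SM100-only sweep.")
    if not supports_sm100(*torch.cuda.get_device_capability()):
        raise SystemExit("Current GPU is not SM100; skipping SM100-only sweep.")
    # bench_fmha_cutedsl(128, 512, 32, 128, True, torch.bfloat16)
    # ... rest of the default sweep
```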

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_attention_cutedsl.py` around lines 153 - 161, The
script currently unconditionally runs an SM100-only kernel in the __main__ block
(calls to bench_fmha_cutedsl), which will JIT/launch-fail on non-SM100 GPUs; add
a GPU capability check before running the default sweep: use
torch.cuda.is_available() and torch.cuda.get_device_capability() or
torch.cuda.get_device_properties(device).major/minor (or device name) to detect
whether the current GPU supports SM100, and if not, skip the default
bench_fmha_cutedsl(...) calls and exit or print a clear message; update the
__main__ section so the SM100-only sweep only runs when the capability check
passes.

Comment on lines +96 to +98
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage,
)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed:

#!/bin/bash
# Check how p_tmem_layout_staged is used and if q_dtype is the correct type
rg -n "p_tmem_layout" --type py flashinfer/cute_dsl/attention/

# Get the full function signature to see all dtype parameters
sed -n '50,120p' flashinfer/cute_dsl/attention/collective_builder.py

# Check sm100_utils.make_smem_layout_a to see what the dtype parameter does
rg -n "def make_smem_layout_a" --type py flashinfer/cute_dsl/attention/

# Check which dtypes are in play: q_dtype, k_dtype, v_dtype, p_dtype
rg -n "q_dtype|k_dtype|v_dtype|p_dtype|intermediate_dtype" --type py flashinfer/cute_dsl/attention/collective_builder.py | head -30

P buffer TMEM layout should use v_dtype, not q_dtype.

The P buffer (attention weights) shares the same pv_tiled_mma with V, which is created with v_dtype. However, p_tmem_layout_staged is created with q_dtype instead, causing a dtype mismatch. Change line 97 to use v_dtype:

Before:
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
    pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage,
)
After:
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
    pv_tiled_mma, config.pv_mma_tiler, v_dtype, mainloop.acc_stage,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 96 - 98,
The p_tmem_layout_staged is being created with the wrong dtype (q_dtype) causing
a mismatch with pv_tiled_mma which was created for V; update the call to
sm100_utils.make_smem_layout_a in collective_builder so p_tmem_layout_staged
uses v_dtype instead of q_dtype (the call that takes pv_tiled_mma,
config.pv_mma_tiler, q_dtype, mainloop.acc_stage should pass v_dtype) to align
the P buffer TMEM layout with the V buffer.

Comment on lines +45 to +53
elif mask_type == MaskType.SLIDING_WINDOW_MASK:
first_q = blk_coord[0] * tile_shape[0]
last_q = (blk_coord[0] + 1) * tile_shape[0] - 1
min_kv = cutlass.max(0, first_q - window_left)
max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
start_block = min_kv // tile_shape[1]
end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
result = end_block - start_block
return result

⚠️ Potential issue | 🟠 Major

Sliding-window masking still ignores the Q/K length offset.

The causal path accounts for seqlen_k - seqlen_q, but the sliding-window branches compute trip counts, start blocks, and element masks from raw Q indices. For suffix-prefill or ragged batches where seqlen_q != seqlen_k, the window gets centered on the wrong KV region.

🛠️ One way to thread the existing offset through the sliding-window path
     elif mask_type == MaskType.SLIDING_WINDOW_MASK:
-        first_q = blk_coord[0] * tile_shape[0]
-        last_q = (blk_coord[0] + 1) * tile_shape[0] - 1
+        q_offset = seqlen_k - seqlen_q
+        first_q = blk_coord[0] * tile_shape[0] + q_offset
+        last_q = (blk_coord[0] + 1) * tile_shape[0] - 1 + q_offset
         min_kv = cutlass.max(0, first_q - window_left)
         max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
         start_block = min_kv // tile_shape[1]
         end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
         result = end_block - start_block
@@
     elif mask_type == MaskType.SLIDING_WINDOW_MASK:
         trip_count = get_trip_count(
-            mask_type, window_left, blk_coord, tile_shape, seqlen_k
+            mask_type, window_left, blk_coord, tile_shape, seqlen_k, seqlen_q
         )
         result = trip_count
@@
     if cutlass.const_expr(mask_type == MaskType.SLIDING_WINDOW_MASK):
-        first_q = blk_coord[0] * tile_shape[0]
+        q_offset = seqlen_k - seqlen_q
+        first_q = blk_coord[0] * tile_shape[0] + q_offset
         min_kv = cutlass.max(0, first_q - window_left)
         return min_kv // tile_shape[1]
@@
     elif mask_type == MaskType.SLIDING_WINDOW_MASK:
         for i in range(cute.size(acc_qk)):
             pos = index_qk[i]
-            if pos[1] - pos[0] > window_left or pos[0] - pos[1] > window_left or pos[1] >= seqlen_k:
+            q_pos = pos[0] + causal_offset
+            if pos[1] - q_pos > window_left or q_pos - pos[1] > window_left or pos[1] >= seqlen_k:
                 acc_qk[i] = -Float32.inf

Also applies to: 85-88, 129-143, 146-170

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/fusion/mask.py` around lines 45 - 53,
Sliding-window masking currently centers the KV window on raw Q indices (see
MaskType.SLIDING_WINDOW_MASK handling using blk_coord, tile_shape, window_left,
seqlen_k) and ignores the Q/K length offset used by the causal path; compute
q_k_offset = seqlen_k - seqlen_q and add it to first_q and last_q (or otherwise
shift Q indices into KV space) before calculating min_kv, max_kv, start_block,
end_block, and any element masks; apply the same fix to the other sliding-window
blocks noted (around the other occurrences at the given ranges) so all
sliding-window computations use shifted Q indices into KV coordinate space.
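
The offset arithmetic in the suggested fix can be checked with a small pure-Python model (the function and parameter names here are illustrative, not code from the PR):

```python
import math

def sliding_window_trip_count(blk_idx, tile_q, tile_kv, window_left, seqlen_q, seqlen_k):
    """Number of KV tiles a Q tile must visit, with Q rows shifted into KV space."""
    q_offset = seqlen_k - seqlen_q          # same offset the causal path applies
    first_q = blk_idx * tile_q + q_offset
    last_q = (blk_idx + 1) * tile_q - 1 + q_offset
    min_kv = max(0, first_q - window_left)
    max_kv = min(seqlen_k - 1, last_q + window_left)
    start_block = min_kv // tile_kv
    end_block = math.ceil((max_kv + 1) / tile_kv)
    return end_block - start_block
```

With seqlen_q == seqlen_k this reduces to the current behavior; with a suffix prefill (seqlen_q < seqlen_k) the window correctly shifts toward the end of the KV sequence instead of staying anchored at block 0.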

Comment on lines +367 to +378
@cute.jit
def update_statistics(self, kv_tile_idx, qo_head_idx, m, d, scale):
log2_e = math.log2(math.exp(1.0))
sink_raw = (
self.params[qo_head_idx] * log2_e / scale
if kv_tile_idx == 0
else -math.inf
)
m_new = sink_raw if sink_raw > m else m
rescale = cute.arch.exp2((m - m_new) * scale)
d_new = cute.arch.exp2((sink_raw - m_new) * scale) + d * rescale
return m_new, d_new

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

CuTe DSL attention sink implementation softmax denominator

💡 Result:

No specific "attention sink implementation" in CuTe DSL was found. CuTe DSL (from NVIDIA CUTLASS) is used for high-performance FlashAttention kernels, including softmax computation via online softmax for numerical stability. The softmax denominator (row sums of exp(scores - row_max)) is computed using tiled reductions in shared memory across blocks/tiles of the attention matrix, leveraging CuTe's tensor layouts, TensorSSA for vectorized ops, and arch-specific reductions like cute.arch.warp_reduction_sum (as in RMSNorm examples, analogous for softmax rows). Official examples like CUTLASS's flash_attention_v2.py feature softmax_rescale_O for online softmax fusion with GEMM. Repos like luliyucoordinate/cute-flash-attention and HuyNguyen-hust/flash-attn-101 implement FlashAttention in CuTe, using standard online softmax without special "sink" handling—sinks are a training/model phenomenon, not a kernel feature. FlashAttention-3/4 in Dao-AILab/flash-attention use CuTe DSL for Hopper/Blackwell with optimized exp2/emulation and reductions for the denominator. Attention sinks refer to probability mass concentrating on fixed positions in softmax attention (e.g., arXiv:2603.11487), but kernel implementations compute general softmax denominators via parallel reductions, not sink-specific logic.


🏁 Scripts executed:

#!/bin/bash
# Locate the variant implementation and read its sink handling and docstrings
find . -type f -name "variant.py" | grep -E "cute_dsl|attention"
sed -n '131,150p' flashinfer/cute_dsl/attention/fusion/variant.py
sed -n '275,410p' flashinfer/cute_dsl/attention/fusion/variant.py
rg -n -B5 -A5 "sink" flashinfer/cute_dsl/attention/fusion/variant.py
rg -n "scale.*=" flashinfer/cute_dsl/attention/ --type py | grep -v test | head -30

# Find the reference implementation and how tests drive sink values
rg -n "AttentionWithSink|def attention_ref|def sink_softmax" --type py
cat tests/test_helpers/sink_attention_reference.py
sed -n '1,120p' tests/test_blackwell_fmha_attention.py
sed -n '420,450p' tests/test_blackwell_fmha_attention.py
sed -n '715,780p' tests/test_blackwell_fmha_attention.py

# Check how scale reaches update_statistics and how m/d are maintained
rg -n "update_statistics.*scale" --type py -B3 -A3 | head -60
sed -n '1,50p' flashinfer/cute_dsl/attention/roles/softmax_math.py

# Look for any docs stating the expected sink value domain
find . -type f \( -name "*.md" -o -name "*.rst" \) -exec grep -l "sink" {} \;

🏁 Script executed:

# Verify the mathematical simplification
python3 << 'EOF'
import math

# Verify: if scale = log2(e) * sm_scale, then params * log2_e / scale = params / sm_scale

log2_e = math.log2(math.exp(1.0))
print(f"log2(e) = {log2_e}")

sm_scale = 1.0 / math.sqrt(128)  # example: head_dim = 128
scale = log2_e * sm_scale

params = 2.5
sink_raw_original = params * log2_e / scale
sink_raw_simplified = params / sm_scale

print(f"sink_raw (original formula) = {sink_raw_original}")
print(f"sink_raw (simplified formula) = {sink_raw_simplified}")
print(f"Match: {abs(sink_raw_original - sink_raw_simplified) < 1e-10}")
EOF


Clarify the expected domain of sink tensor values in the docstring.

The mathematical conversion sink_raw = params * log2_e / scale is correct and simplifies to params / sm_scale as stated. However, the class docstring does not document the domain in which sink values are expected. The reference implementation (sink_softmax in sink_attention_reference.py) concatenates sink values directly to logits, confirming they should be in the logit domain (raw QK dot-product units, unnormalized). Add this to the sink parameter documentation to clarify that users must provide sink values in the logit domain, not pre-converted to log2 scale.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/fusion/variant.py` around lines 367 - 378,
Update the class and relevant parameter docstrings to state that sink values are
expected in the logit domain (raw Q·K dot-product units, unnormalized), not
pre-scaled to log2; specifically mention this near the documentation for the
sink parameter(s) used by update_statistics and the self.params/sink_raw
conversion (which divides by scale/log2_e), and add a cross-reference to
sink_softmax in sink_attention_reference.py so callers know sinks are
concatenated to logits before any log2 scaling.
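
To make the convention concrete, here is a pure-Python sketch of the online m/d update under the stated assumption that scale = log2(e) * sm_scale and the sink arrives in logit units (the function name and the merged per-step update are illustrative, not the kernel's structure):

```python
import math

def online_m_d_with_sink(logits, sink, sm_scale):
    """Streaming running-max / denominator update; the sink folds in at KV step 0."""
    log2_e = math.log2(math.e)
    scale = log2_e * sm_scale                 # the kernel's combined log2-domain scale
    m, d = -math.inf, 0.0
    for step, x in enumerate(logits):
        # sink * log2_e / scale == sink / sm_scale
        sink_raw = sink * log2_e / scale if step == 0 else -math.inf
        m_new = max(m, sink_raw, x)
        rescale = 2.0 ** ((m - m_new) * scale)
        d = d * rescale + 2.0 ** ((sink_raw - m_new) * scale) + 2.0 ** ((x - m_new) * scale)
        m = m_new
    return m, d
```

Since sink_raw * scale == sink * log2(e), the sink contributes exactly exp(sink) to the recovered denominator d * 2**(m * scale), independent of sm_scale.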

Comment on lines +551 to +554
@cute.jit
def score_mod(self, score, batch_idx, qo_idx, kv_idx,
qo_head_idx, kv_head_idx):
return self.cap * cute.arch.tanh(score * self.rcp_cap)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Scripts executed:

#!/bin/bash
# Search for tanh usage or definition anywhere in the repo
rg -n "tanh" --type py -g '!*test*' | head -30
rg -n "cute\.arch\.tanh" --type py

# Inspect the imports in variant.py and the call sites of cute.arch
head -50 flashinfer/cute_dsl/attention/fusion/variant.py
rg -n "^import cute|^from cute" flashinfer/cute_dsl/attention/fusion/variant.py
rg -n "cute\.arch" flashinfer/cute_dsl/attention/fusion/variant.py -B 2 -A 2 | head -50

# Check whether cute is a local module or an external dependency
find . -name "cute.py" -o -name "cute" -type d | grep -v __pycache__
find . -name "requirements*.txt" -o -name "setup.py" -o -name "pyproject.toml" | xargs grep -l cute 2>/dev/null | head -10
cat pyproject.toml 2>/dev/null | head -50

# Look for any local definitions, stubs, or bindings of these intrinsics
rg -n "class arch|def tanh|def exp2|def rcp_approx" --type py | head -30
find . -name "*.pyi" | head -10
rg -n "cute\.arch\.(tanh|exp2|rcp_approx)" --type py -l


🌐 Web query:

CUTLASS cute arch tanh exp2 rcp_approx Python bindings documentation

💡 Result:

CUTLASS provides Python bindings through its CuTe DSL (Domain-Specific Language), enabling high-performance CUDA kernel development in Python. The CuTe architecture module (cutlass.cute.arch) includes wrappers for NVVM-level MLIR operations that integrate with CuTe Python types. Key functions in cutlass.cute.arch relevant to the query: - rcp_approx(a: float | cutlass.cute.typing.Float32, *, loc=None, ip=None) → cutlass.cute.typing.Float32: Computes approximate reciprocal. - exp2(a: float | cutlass.cute.typing.Float32, *, loc=None, ip=None) → cutlass.cute.typing.Float32: Computes base-2 exponential. - No explicit tanh function found in the arch module; it may require composition of other operations or be available elsewhere in CuTe. Official documentation: - Primary source: https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_api/cute_arch.html (cutlass.cute.arch API) - Installation: Use setup.sh from CUTLASS repo (e.g., ./cutlass/python/CuTeDSL/setup.sh --cu13 for CUDA 13.1) or pip install nvidia-cutlass-dsl[cu13]. - GitHub: https://github.com/NVIDIA/cutlass (v4.4.2 as of 2026-03-17). Example usage in a @cute.jit kernel: import cutlass.cute as cute @cute.jit def kernel(a: cute.Tensor): recip = cute.arch.rcp_approx(a[0]) exp = cute.arch.exp2(recip) CuTe is part of CUTLASS 4.x, supporting NVIDIA architectures from Ampere to Blackwell.


🏁 Scripts executed:

#!/bin/bash
# See how tanh is used elsewhere in cute_dsl and whether cute.math provides it
rg -n "tanh" flashinfer/cute_dsl --type py -A 3 -B 3
rg -n "cute\.math\." flashinfer/cute_dsl --type py | head -20
rg -n "def tanh" . --type py

# Check whether SoftCappingAttention / score_mod are exercised by tests
rg -n "SoftCappingAttention" . --type py
rg -n "score_mod" tests/ --type py -A 5 | head -40


🌐 Web query:

CUTLASS cute tanh hyperbolic tangent math operations

💡 Result:

CUTLASS and CuTe (CUTE) do not provide built-in dedicated support for hyperbolic tangent (tanh) math operations within their core GEMM kernels or tensor algorithms. Tanh is available in related NVIDIA libraries but requires custom implementation for fusion in CUTLASS epilogues or CuTe usage. Key findings: - CuTe tensor algorithms (cute/algorithm/) support copy, gemm, axpby, fill, clear, but no tanh or unary activations like tanh listed in official docs. - CUTLASS epilogues support fused activations via cutlass::epilogue::thread::LinearCombination + Activation functors (e.g., ReLU, GELU, SiLU/LeakyReLU in activation.h), but no built-in tanh functor. Custom tanh can be implemented by extending LinearCombinationGeneric or Epilogue Visitor Trees (EVT) for CUTLASS 3.x, as confirmed in GitHub issues where users request sigmoid/tanh fusion. - Separate NVIDIA libraries provide tanh: cuda.tile.tanh in cuTile Python (tile-level), nv_tanh in libdevice (scalar). These are not directly integrated into CUTLASS/CuTe GEMM epilogues. - For performance-critical use (e.g., fused GEMM + tanh), implement custom epilogue op calling CUDA math intrinsics like nv_tanhf (fast float version), or approximate tanh (e.g., K-TanH from research papers). To use: In CUTLASS 3.x, extend Sm90EVT or similar with custom compute node using tanh; in 2.x, subclass LinearCombination with custom activation functor calling tanh. No official CUTLASS 4.x (as of 2026-03) release notes mention tanh addition.


cute.arch.tanh is not available in CUTLASS CuTe — this will cause JIT compilation to fail.

SoftCappingAttention.score_mod uses cute.arch.tanh, but this function does not exist in the CuTe DSL. CUTLASS/CuTe provides cute.arch.rcp_approx and cute.arch.exp2, but not tanh. The kernel will fail at JIT compile time with an AttributeError.

Implement tanh using available operations (e.g., composition via exp2 and rcp_approx) or use an alternative approximation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/fusion/variant.py` around lines 551 - 554,
SoftCappingAttention.score_mod calls non-existent cute.arch.tanh; replace it
with a local tanh approximation implemented using available primitives (e.g.,
cute.arch.exp2 and cute.arch.rcp_approx) or a cheap rational polynomial and call
that from score_mod. Add a helper function (e.g., _tanh_approx(x)) in the same
class or module and use it in SoftCappingAttention.score_mod (referencing
self.cap and self.rcp_cap as before), implementing tanh(x) via exp2 by computing
exp(-2*abs(x)) with exp2(-2*abs(x)/ln2) plus sign handling or by a stable
rational approximation (polynomial numerator/denominator) and ensure the helper
uses cute.jit-compatible operations only.

Comment on lines +68 to +79
When cluster_scale > 1, the all-thread side of
UMMA_ASYNC / ASYNC_UMMA pipelines multiplies its thread count by
cluster_scale. TMA_UMMA pipelines are unaffected (leader-only on both sides).
"""

name: str
pipeline_type: PipelineType
stages: int
producer_warp_ids: Tuple[int, ...]
consumer_warp_ids: Tuple[int, ...]
tx_count_key: str | None = None
cluster_scale: int = 1

⚠️ Potential issue | 🟠 Major

cluster_scale is currently a no-op.

The dataclass docs say this field changes participant counts, but create_pipelines() always instantiates the producer/consumer groups as if it were 1. Any topology that starts using cluster_scale > 1 will build barriers with the wrong arrive counts.

🛠️ Either honor the field or fail fast when it is set
             prod_threads = edge.pipeline_type.producer_thread_count(
                 len(edge.producer_warp_ids), threads_per_warp
             )
             cons_threads = edge.pipeline_type.consumer_thread_count(
                 len(edge.consumer_warp_ids), threads_per_warp
             )
+            if edge.cluster_scale != 1:
+                if edge.pipeline_type == PipelineType.UMMA_ASYNC:
+                    cons_threads *= edge.cluster_scale
+                elif edge.pipeline_type == PipelineType.ASYNC_UMMA:
+                    prod_threads *= edge.cluster_scale
+                else:
+                    raise ValueError(
+                        f"cluster_scale is unsupported for {edge.pipeline_type}"
+                    )

Also applies to: 127-147

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/pipeline_topology.py` around lines 68 - 79, The
Pipeline.dataclass field cluster_scale is ignored by create_pipelines(), causing
incorrect participant and barrier arrive counts; either (preferred) honor it by
multiplying the all-thread side's participant counts when constructing
producer/consumer groups and computing barrier arrive counts for PipelineType
values UMMA_ASYNC and ASYNC_UMMA (but leave TMA_UMMA unchanged), i.e., when
building groups from producer_warp_ids/consumer_warp_ids in create_pipelines()
multiply the thread counts by pipeline.cluster_scale and use that scaled value
when setting arrive counts for barriers/tx_count_key, or fail fast by adding a
check in create_pipelines() that raises a clear exception if
pipeline.cluster_scale != 1 so callers must handle scaling explicitly.
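
For reference, the arrive-count arithmetic the suggestion encodes can be modeled in plain Python (the function and the string-typed pipeline names are illustrative stand-ins for the DSL types):

```python
def participant_counts(pipeline_type, n_prod_warps, n_cons_warps,
                       cluster_scale=1, threads_per_warp=32):
    """Producer/consumer thread counts with the all-thread side scaled across the cluster."""
    prod = n_prod_warps * threads_per_warp
    cons = n_cons_warps * threads_per_warp
    if cluster_scale != 1:
        if pipeline_type == "UMMA_ASYNC":    # all-thread side is the consumer
            cons *= cluster_scale
        elif pipeline_type == "ASYNC_UMMA":  # all-thread side is the producer
            prod *= cluster_scale
        else:                                # e.g. TMA_UMMA: leader-only on both sides
            raise ValueError(f"cluster_scale is unsupported for {pipeline_type}")
    return prod, cons
```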

Comment on lines +41 to +66
# NOTE on CuTe DSL JIT limitations:
# - partition_output(): Returns tensor tuples — CuTe DSL JIT does not
# reliably handle returning tensors from @cute.jit methods.
# - store_tile(): SAFE — takes pre-sliced tensors as arguments, no
# runtime indexing or return values. Used in run() successfully.
# =========================================================================

@cute.jit
def partition_output(
self,
tma_atom_o: cute.CopyAtom,
mO_qdl: cute.Tensor,
sO: cute.Tensor,
block_coord: tuple,
):
"""Partition output global tensor for TMA stores. Returns (tOsO, tOgO)."""
gO_qdl = cute.flat_divide(
mO_qdl, cute.select(self.pv_mma_tiler, mode=[0, 1])
)
gO = gO_qdl[None, None, None, 0, block_coord[2]]
tOsO, tOgO = cute.nvgpu.cpasync.tma_partition(
tma_atom_o, 0, cute.make_layout(1),
cute.group_modes(sO, 0, 2),
cute.group_modes(gO, 0, 2),
)
return tOsO, tOgO

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Does NVIDIA CUTLASS CuTe DSL reliably support returning tensor objects from @cute.jit functions or methods, or should helpers that return tensors stay non-jitted/inlined?

💡 Result:

No, NVIDIA CUTLASS CuTe DSL does not reliably support returning tensor objects from @cute.jit functions or methods. Helpers that return tensors should stay non-jitted/inlined.


Remove @cute.jit from partition_output() or refactor to avoid returning tensors.

This method is decorated with @cute.jit but returns a tensor tuple, contradicting the documented limitation directly above it: "CuTe DSL JIT does not reliably handle returning tensors from @cute.jit methods." Since this is exported as a reusable primitive, it violates its own safety contract. Either remove the @cute.jit decorator or restructure to keep tensor returns outside JIT compilation per NVIDIA documentation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/roles/epilogue.py` around lines 41 - 66,
partition_output is incorrectly decorated with `@cute.jit` while returning tensor
objects (tOsO, tOgO) which violates the CuTe JIT limitation; either remove the
`@cute.jit` decorator from partition_output so it runs as a normal Python method,
or refactor it to avoid returning tensors by (a) accepting preallocated output
containers/handles and writing into them, or (b) moving the
cute.nvgpu.cpasync.tma_partition call out of the `@cute.jit` function into a
non-jit wrapper (e.g., create partition_output_nonjit that calls
cute.nvgpu.cpasync.tma_partition and returns tensors or change partition_output
to populate passed-in tensor references); update references to partition_output
accordingly so no `@cute.jit` function returns tensors (symbols: partition_output,
tOsO, tOgO, tma_partition, tma_atom_o).

Comment on lines +17 to +71
@dataclass(frozen=True)
class WarpSchedule:
"""Defines warp role assignment and register budgets for attention kernels.

Each field maps directly to C++ CUTLASS's KernelSchedule:
- Warp ID ranges for each role
- Register allocation per role (controls spill/occupancy tradeoff)
- Barrier IDs for CTA sync and TMEM allocation
"""

softmax0_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
softmax1_warp_ids: Tuple[int, ...] = (4, 5, 6, 7)
correction_warp_ids: Tuple[int, ...] = (8, 9, 10, 11)
mma_warp_id: int = 12
load_warp_id: int = 13
epilogue_warp_id: int = 14
empty_warp_id: int = 15

num_regs_softmax: int = 192
num_regs_correction: int = 96
num_regs_other: int = 32
num_regs_empty: int = 24

threads_per_warp: int = 32
cta_sync_bar_id: int = 0
tmem_alloc_sync_bar_id: int = 1

@property
def all_warp_ids(self) -> Tuple[int, ...]:
return (
*self.softmax0_warp_ids,
*self.softmax1_warp_ids,
*self.correction_warp_ids,
self.mma_warp_id,
self.load_warp_id,
self.epilogue_warp_id,
self.empty_warp_id,
)

@property
def num_warps(self) -> int:
return len(self.all_warp_ids)

@property
def threads_per_cta(self) -> int:
return self.threads_per_warp * self.num_warps

@property
def num_warps_per_warpgroup(self) -> int:
return 4

@property
def softmax_warpgroup_count(self) -> int:
total_softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
return total_softmax_warps // self.num_warps_per_warpgroup

⚠️ Potential issue | 🟠 Major

Validate custom schedules before deriving CTA sizes.

num_warps, threads_per_cta, and softmax_warpgroup_count all assume the warp ids are unique, contiguous from 0, and that the softmax warps fill whole warpgroups. With a custom WarpSchedule, duplicate/gapped ids or a non-multiple-of-4 softmax set will silently produce the wrong CTA/barrier sizing.

🛠️ Suggested fail-fast validation
 @dataclass(frozen=True)
 class WarpSchedule:
@@
     threads_per_warp: int = 32
     cta_sync_bar_id: int = 0
     tmem_alloc_sync_bar_id: int = 1
+
+    def __post_init__(self):
+        all_warp_ids = self.all_warp_ids
+        if len(set(all_warp_ids)) != len(all_warp_ids):
+            raise ValueError("warp ids must be unique across roles")
+        if tuple(sorted(all_warp_ids)) != tuple(range(len(all_warp_ids))):
+            raise ValueError("warp ids must form a contiguous range starting at 0")
+        total_softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
+        if total_softmax_warps % self.num_warps_per_warpgroup != 0:
+            raise ValueError("softmax warps must fill whole warpgroups")
+        if self.cta_sync_bar_id == self.tmem_alloc_sync_bar_id:
+            raise ValueError("barrier ids must be distinct")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/cute_dsl/attention/warp_schedule.py` around lines 17-71, add a
fail-fast validation in WarpSchedule (implement in a __post_init__ method) that
verifies: 1) all_warp_ids (built from softmax0_warp_ids, softmax1_warp_ids,
correction_warp_ids, mma_warp_id, load_warp_id, epilogue_warp_id, empty_warp_id)
contain unique values and form a contiguous range starting at 0 up to
len(all_warp_ids)-1, and 2) the total number of softmax warps
(len(softmax0_warp_ids)+len(softmax1_warp_ids)) is divisible by
num_warps_per_warpgroup; on violation raise ValueError with a clear message
referencing the failing condition so consumers of num_warps, threads_per_cta,
and softmax_warpgroup_count cannot silently compute incorrect sizes.
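The suggested validation can be exercised standalone. A minimal sketch (pure Python, independent of the CuTe DSL; the field set is trimmed and `num_warps_per_warpgroup` is a plain field here rather than the property in the real class):

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class WarpSchedule:
    softmax0_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
    softmax1_warp_ids: Tuple[int, ...] = (4, 5, 6, 7)
    correction_warp_ids: Tuple[int, ...] = (8, 9, 10, 11)
    mma_warp_id: int = 12
    load_warp_id: int = 13
    epilogue_warp_id: int = 14
    empty_warp_id: int = 15
    num_warps_per_warpgroup: int = 4

    @property
    def all_warp_ids(self) -> Tuple[int, ...]:
        return (*self.softmax0_warp_ids, *self.softmax1_warp_ids,
                *self.correction_warp_ids, self.mma_warp_id,
                self.load_warp_id, self.epilogue_warp_id, self.empty_warp_id)

    def __post_init__(self):
        ids = self.all_warp_ids
        if len(set(ids)) != len(ids):
            raise ValueError("warp ids must be unique across roles")
        if tuple(sorted(ids)) != tuple(range(len(ids))):
            raise ValueError("warp ids must form a contiguous range starting at 0")
        n_softmax = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
        if n_softmax % self.num_warps_per_warpgroup != 0:
            raise ValueError("softmax warps must fill whole warpgroups")

WarpSchedule()                  # default schedule passes
# WarpSchedule(mma_warp_id=0)   # raises ValueError: duplicates softmax warp 0
```

Because the dataclass is frozen, `__post_init__` runs after all fields are set, so it can read the derived `all_warp_ids` property safely.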

Comment on lines +159 to +169
self._has_params = self._variant.extra_params is not None
if self._has_params:
ep = self._variant.extra_params.to(torch.float32).to(self._device)
if not ep.is_contiguous():
raise ValueError(
f"AttentionVariant.extra_params must be contiguous, "
f"got strides {ep.stride()} for shape {ep.shape}. "
f"Call .contiguous() before returning from extra_params."
)
self._params_torch = ep
params_cute = from_dlpack(ep, assumed_align=16)

⚠️ Potential issue | 🟡 Minor

Potential NameError if params_cute is referenced when _has_params is False.

params_cute is defined only inside the if self._has_params: block (line 169), but it is referenced again at line 240.

Line 240 uses the conditional expression params_cute.iterator if self._has_params else None, and Python evaluates the condition first, so the guarded branch does not actually raise at runtime when _has_params is False. Relying on a conditionally defined local is still fragile, though: if the flag and the assignment ever drift out of sync, the read raises NameError. Initializing params_cute = None unconditionally makes the invariant explicit.

Proposed fix
         if self._has_params:
             ep = self._variant.extra_params.to(torch.float32).to(self._device)
             if not ep.is_contiguous():
                 raise ValueError(
                     f"AttentionVariant.extra_params must be contiguous, "
                     f"got strides {ep.stride()} for shape {ep.shape}. "
                     f"Call .contiguous() before returning from extra_params."
                 )
             self._params_torch = ep
             params_cute = from_dlpack(ep, assumed_align=16)
+        else:
+            params_cute = None

Then at line 240:

-            params_cute.iterator if self._has_params else None,
+            params_cute.iterator if params_cute is not None else None,

Also applies to line 240.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 159-169,
the NameError risk comes from params_cute being defined only inside the if
self._has_params block yet referenced later; fix by defining params_cute = None
before the if and only assigning it inside the block (where you call
from_dlpack) so later code can safely use the conditional expression
(params_cute.iterator if self._has_params else None); update references
involving self._has_params, _params_torch, and from_dlpack accordingly to rely
on the initialized params_cute variable.
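The initialize-before-branch pattern is easy to demonstrate in isolation. In this sketch, `FakeCuteTensor` and the `from_dlpack` stand-in are hypothetical placeholders, not the real cutlass API:

```python
class FakeCuteTensor:
    # Placeholder for the tensor object returned by cutlass's from_dlpack.
    def __init__(self, data):
        self.iterator = iter(data)

def from_dlpack(t, assumed_align=16):
    # Stand-in for the real from_dlpack conversion.
    return FakeCuteTensor(t)

def build_params(extra_params):
    params_cute = None  # defined on every path, so later reads never NameError
    if extra_params is not None:
        params_cute = from_dlpack(extra_params)
    # Safe whether or not extra_params was provided:
    return params_cute.iterator if params_cute is not None else None
```

Guarding on `params_cute is not None` instead of a separate boolean flag keeps the check and the assignment tied to the same variable, so they cannot drift apart.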
