
Ameyn/gdn decode cutedsl kernel #2498

Open
ameynaik-hub wants to merge 9 commits into flashinfer-ai:main from ameynaik-hub:ameyn/gdn-decode-cutedsl-kernel

Conversation


@ameynaik-hub ameynaik-hub commented Feb 5, 2026

Implements a high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, and T=4, built with NVIDIA CuTe-DSL.

Key features:

  • H state layout: K-last [B, HV, V, K] where K is the contiguous (fastest) dimension
  • L2-normalized Q/K with configurable scale
  • Gated exponential decay via softplus
  • Delta rule updates: v_delta = beta * (v - pred)
  • Async H memory loading with aggressive pipelining
  • BF16 tensors with FP32 compute

Also includes:

  • benchmark_gated_delta_rule.py: Simple benchmark script for measuring kernel perf
  • Updated __init__.py exports

📌 Description

Implements a high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, and T=4, built with NVIDIA CuTe-DSL.
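For reviewers less familiar with the recurrence, here is a minimal PyTorch sketch of the per-token update the kernel implements (shapes follow this PR's K-last state layout; the helper name and the precomputed gate inputs are illustrative, not the shipped API):

```python
import torch
import torch.nn.functional as F

def gated_delta_rule_reference_step(q, k, v, g, beta, h, scale, eps=1e-6):
    """One decode step for a single (batch, value-head). Illustrative only.

    q, k: [K]    query/key for this token
    v:    [V]    value for this token
    g:    scalar gate, e.g. exp(-exp(A_log) * softplus(a + dt_bias)) as in this PR's gating
    beta: scalar in (0, 1), sigmoid of the raw beta input
    h:    [V, K] recurrent state, K-last as in this PR's layout
    """
    # L2-normalize q/k and apply the configurable query scale (done in-kernel here)
    q = F.normalize(q.float(), dim=-1, eps=eps) * scale
    k = F.normalize(k.float(), dim=-1, eps=eps)

    h = h.float() * g                      # gated exponential decay of the state
    pred = h @ k                           # [V]  prediction of v from the decayed state
    v_delta = beta * (v.float() - pred)    # delta rule update: beta * (v - pred)
    h = h + torch.outer(v_delta, k)        # rank-1 state update, [V, K]
    o = h @ q                              # [V]  output for this token
    return o, h
```

Looping this step over T <= 4 tokens per head, with BF16 tensors and FP32 accumulation, is what the specialized kernels in this PR implement.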

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Summary by CodeRabbit

  • New Features

    • Added CUDA-accelerated Gated Delta Rule kernel with optimized paths for sequence lengths 1–4 and new public API entries gated_delta_rule and GatedDeltaRuleKernel (exported when available).
  • Benchmarks

    • New benchmark for the gated_delta_rule kernel and integration of an improved CuTe-DSL variant into the benchmark suite and CLI, with per‑T summaries and performance reporting.
  • Tests

    • End-to-end tests exercising the improved kernel for T=1–4, availability guards, and updated reference tests to support explicit state storage dtype handling.

Implements high-performance Gated Delta Rule linear attention kernel supporting
fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.

Key features:
- H state layout: K-last [B, HV, V, K] where K is the contiguous (fastest) dimension
- Unified kernel architecture: T=2/3/4 share a single compile-time specialized
  kernel via Constexpr dispatch; T=1 uses separate kernel with persistent K optimization
- L2-normalized Q/K with configurable scale
- Gated exponential decay via softplus
- Delta rule updates: v_delta = beta * (v - pred)
- Bank-conflict-free cross-warp reductions
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute for numerical stability
- GQA (grouped-query attention) support

Also includes:
- benchmark_gated_delta_rule.py: Simple benchmark script for measuring kernel perf
- Updated __init__.py exports

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @ameynaik-hub, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a highly optimized CUDA kernel for the Gated Delta Rule linear attention mechanism, specifically tailored for decode-phase inference. The implementation, built with NVIDIA CuTe-DSL, provides high performance for fixed sequence lengths of 1, 2, 3, and 4. It incorporates advanced optimizations such as a K-last H state layout, L2-normalized Q/K, gated exponential decay, delta rule updates, and aggressive asynchronous memory pipelining, all while maintaining numerical stability with BF16 tensors and FP32 compute. A new benchmark script is also included to evaluate the kernel's performance.

Highlights

  • New Gated Delta Rule Kernel: Implements a high-performance Gated Delta Rule linear attention kernel for decode-phase inference.
  • Sequence Length Support: Supports fixed sequence lengths T=1, T=2, T=3, T=4 with specialized optimizations.
  • CuTe-DSL Implementation: Leverages NVIDIA CuTe-DSL for optimized CUDA kernel development, ensuring high performance.
  • Optimized H State Layout: Uses a K-last H state layout [B, HV, V, K] where K is the contiguous (fastest) dimension.
  • Q/K Normalization & Gating: Features L2-normalized Q/K with configurable scale and gated exponential decay via softplus.
  • Delta Rule Updates: Incorporates delta rule updates: v_delta = beta * (v - pred).
  • Memory & Pipelining: Utilizes asynchronous H memory loading with aggressive pipelining for improved efficiency.
  • Mixed Precision Compute: Employs BF16 tensors with FP32 compute for numerical stability.
  • Benchmarking Script: Includes a new benchmark script (benchmark_gated_delta_rule.py) for measuring kernel performance.
  • API Export: Exports the new gated_delta_rule function and GatedDeltaRuleKernel class in flashinfer/cute_dsl/__init__.py.
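A rough usage sketch of the exported entry point, based on the shapes, dtypes, and keyword names that appear in this PR's benchmark and review notes (the sizes below are arbitrary examples, and K = V = 128 matches the kernel's assumptions):

```python
import torch
from flashinfer.cute_dsl import gated_delta_rule

B, T, H, HV, K, V = 8, 2, 4, 8, 128, 128  # example sizes; K = V = 128 per this kernel
dev = "cuda"

q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=dev)
k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=dev)
v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=dev)
a = torch.randn(B, T, HV, dtype=torch.bfloat16, device=dev)        # gate input
b = torch.randn(B, T, HV, dtype=torch.bfloat16, device=dev)        # beta input
A_log = torch.randn(HV, dtype=torch.float32, device=dev)
dt_bias = torch.randn(HV, dtype=torch.float32, device=dev)
state = torch.zeros(B, HV, V, K, dtype=torch.bfloat16, device=dev)  # K-last H state

out = gated_delta_rule(
    A_log=A_log, a=a, dt_bias=dt_bias,
    q=q, k=k, v=v, b=b,
    initial_state_source=state,
    scale=K ** -0.5,  # example scale; the kernel takes it as a parameter
)
```

The state tensor appears to be updated in place (the kernel writes H back to global memory), while the output is returned as a new [B, T, HV, V] tensor.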


Changelog
  • flashinfer/cute_dsl/__init__.py
    • Added exports for gated_delta_rule and GatedDeltaRuleKernel to make the new functionality accessible.
  • flashinfer/cute_dsl/benchmark_gated_delta_rule.py
    • Introduced a new Python script for benchmarking the performance of the Gated Delta Rule CuTe-DSL kernel across various batch sizes and sequence lengths.
  • flashinfer/cute_dsl/gated_delta_rule.py
    • Implemented the core Gated Delta Rule linear attention kernel.
    • Includes shared helper functions for common operations.
    • Features a specialized kernel for seqlen=1 with persistent K optimization and aggressive pipelining.
    • Provides a unified kernel for seqlen=2, 3, 4 using compile-time Constexpr specialization.
Activity
  • The author has completed all pre-commit checks, ensuring code quality and style.
  • Tests have been added or updated, and all are passing, indicating the new functionality is working as expected and existing functionality is not regressed.

@coderabbitai
Contributor

coderabbitai bot commented Feb 5, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a CuTe-DSL Gated Delta Rule CUDA kernel (seq_len 1–4) with a Python launcher/class, a standalone benchmark, benchmark integration, reference dtype/state handling updates, and tests exercising the new kernel and BF16 state paths.

Changes

  • Public API Exports (flashinfer/cute_dsl/__init__.py)
    Export gated_delta_rule and GatedDeltaRuleKernel under the CuTe-DSL availability guard and update __all__.
  • Gated Delta Rule Core Implementation (flashinfer/cute_dsl/gated_delta_rule.py)
    New CuTe-DSL implementation: compile/select kernels for T∈{1,2,3,4}, BF16-storage/FP32-compute, GQA/per-token gating, async loads, bank-conflict-free reductions, kernel caching, Python API gated_delta_rule(...), GatedDeltaRuleKernel, and multiple launch entry points.
  • Kernel Benchmark (flashinfer/cute_dsl/benchmark_gated_delta_rule.py)
    New benchmarking script: L2 cache handling, CUDA-event timing, warmup, input synthesis, multi-config runs and formatted reporting.
  • Benchmark Integration (benchmarks/bench_gdn_decode.py)
    Integrates an "Improved CuTe-DSL" path: wrapper and bench runner, CLI --version improved_cutedsl, T=1..4 support, availability guards, and updated result tables / speedup metrics.
  • Reference Implementations / dtype handling (tests/gdn/reference_delta_rule.py)
    Introduce a state_dtype parameter; perform compute in FP32 and store states in state_dtype; update signatures, casts, and docstrings across delta/GDN helpers.
  • Tests for Improved CuTe-DSL (tests/gdn/test_decode_delta_rule.py)
    Add import/availability flag for the improved CuTe-DSL kernel, a new test helper and parametrized tests for T=1..4 using BF16 state, conditional execution when the kernel is unavailable, and test harness updates.

Sequence Diagram(s)

sequenceDiagram
    participant PythonAPI as Python API
    participant KernelCache as Kernel Cache
    participant CuTe as CuTe Compiler
    participant GPU as GPU Kernel
    participant Memory as GPU Memory

    PythonAPI->>PythonAPI: validate inputs (q,k,v,A_log,a,dt_bias,seq_len)
    PythonAPI->>KernelCache: lookup compiled kernel (T, dtypes)
    alt cached
        KernelCache-->>PythonAPI: return cached kernel
    else not cached
        PythonAPI->>CuTe: compile kernel for (T, dtypes)
        CuTe-->>KernelCache: store compiled kernel
        KernelCache-->>PythonAPI: return compiled kernel
    end
    PythonAPI->>GPU: launch kernel on CUDA stream with args
    GPU->>Memory: async load Q,K,V,gates,state blocks
    GPU->>GPU: l2-normalize, gated decay, delta-rule updates, reductions
    GPU->>Memory: write output [B,T,HV,V] and updated state
    GPU-->>PythonAPI: signal completion / return tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

v0.6.2

Suggested reviewers

  • yzh119
  • cyx-6
  • nvmbreughe
  • bkryu
  • jimmyzho
  • jiahanc

Poem

🐇 I hop through kernels, compile with cheer,

Four seq-paths lined up, BF16 held dear.
Async loads whisper, reductions take flight,
Gates and deltas hum through CUDA night.
A tiny rabbit's clap — benchmarks gleam bright.

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 73.08%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check (❓ Inconclusive): The title is vague and lacks specificity about the main change; it uses abbreviations (Ameyn, gdn, cutedsl) without context and does not clearly convey the primary objective. Resolution: clarify the title to describe the main change in plain language, e.g., 'Add high-performance Gated Delta Rule kernel for linear attention decoding' or similar.

✅ Passed checks (1 passed)

  • Description check (✅ Passed): The description includes relevant technical details about the kernel implementation and marks pre-commit checks and tests as complete, meeting most template requirements.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a high-performance Gated Delta Rule linear attention kernel using CuTe-DSL, supporting sequence lengths from 1 to 4. The implementation includes separate, optimized kernels for T=1 and a unified kernel for T=2,3,4, along with a benchmark script. My review focuses on the new kernel implementation in gated_delta_rule.py. I've identified a few areas for improvement: the seqlen=1 kernel contains significant code duplication that could be refactored for better maintainability. The kernel caching strategy could lead to performance issues with dynamic batch sizes due to unnecessary recompilations. Finally, the exported GatedDeltaRuleKernel class appears to be incomplete or unused. Addressing these points will improve the robustness and performance of the new kernel.

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Check cache
    cache_key = (T, B)
Contributor


high

The kernel caching mechanism uses the batch size B as part of the cache key. Because the kernel is compiled with concrete tensor shapes (derived from from_dlpack), this will trigger a recompilation for every new batch size encountered at runtime. For applications with dynamic batch sizes, which are common in inference scenarios, this can lead to significant performance degradation due to repeated JIT compilation.

To address this, I recommend using symbolic dimensions for the batch size during compilation to generate a more generic kernel. This can be achieved by using cute.runtime.make_fake_compact_tensor with a cute.sym_int() for the batch dimension, instead of using from_dlpack directly on the input tensors within the compilation path. You can find an example of this pattern in flashinfer/cute_dsl/add_rmsnorm_fp4quant.py. This change would allow you to remove the batch size from the cache key, preventing unnecessary recompilations and improving runtime performance.
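If moving to symbolic shapes is deferred, a smaller interim step is to at least key the cache on every dimension the compiled kernel specializes on. A minimal sketch (the cute.compile argument list is elided because it mirrors whatever the launch functions already receive):

```python
# Illustrative caching pattern; mirrors the module-level cache in this PR.
_compiled_kernels = {}

def _get_or_compile(launch_fn, T, B, H, HV, K, V, compile_args):
    # Key on every dimension the compiled kernel bakes in, not just (T, B).
    cache_key = (T, B, H, HV, K, V)
    kernel = _compiled_kernels.get(cache_key)
    if kernel is None:
        # compile_args stands in for the dlpack-wrapped tensors/scalars that the
        # real code already passes; each distinct key pays a one-time JIT cost.
        kernel = cute.compile(launch_fn, *compile_args)
        _compiled_kernels[cache_key] = kernel
    return kernel
```

The fuller fix remains the symbolic batch dimension described above, which would also remove B from the key.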

Comment on lines 691 to 1085
def gated_delta_rule_decode_kernel_seqlen1(
    gQ: cute.Tensor,
    gK: cute.Tensor,
    gV: cute.Tensor,
    ga: cute.Tensor,
    gb: cute.Tensor,
    gA_log: cute.Tensor,
    gdt_bias: cute.Tensor,
    gH: cute.Tensor,
    gO: cute.Tensor,
    scale: cutlass.Float32,
    softplus_beta: cutlass.Float32,
    softplus_threshold: cutlass.Float32,
    eps: cutlass.Float32,
):
    """
    Seqlen=1 kernel with persistent K optimization.
    OPTIMIZATIONS:
    1. PERSISTENT K IN REGISTERS ONLY: K[k_base:k_base+32] kept for entire kernel
       Q is reloaded per chunk (lower register pressure than V3)
    2. AGGRESSIVE PIPELINING: Load chunks 2 ahead, store during next compute
    3. [4,32] CROSS-WARP REDUCTION: Correct lane-preserving reduction
    """
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    HV = cutlass.Int32(gV.shape[2])
    H = cutlass.Int32(gQ.shape[2])

    batch_idx = bidx // HV
    value_head_idx = bidx % HV
    query_head_idx = value_head_idx // (HV // H)

    smem = utils.SmemAllocator()

    # Compute gates using shared helper
    alpha = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32)
    beta_raw = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32)
    A_log_val = gA_log[value_head_idx]
    dt_bias_val = gdt_bias[value_head_idx]
    g_exp, beta = compute_single_gate(
        alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold
    )

    # Allocate SMEM
    h_sh_chunk0 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk1 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk2 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk3 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )

    q_sh = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh = smem.allocate_tensor(cutlass.Float32, 128)

    pred_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32)))
    out_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32)))

    h_global = gH[(batch_idx, value_head_idx, None, None)]

    # Launch first 2 async loads
    load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0)
    nvvm.cp_async_commit_group()
    load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32)
    nvvm.cp_async_commit_group()

    # L2 normalization
    q_head = gQ[(batch_idx, 0, query_head_idx, None)]
    k_head = gK[(batch_idx, 0, query_head_idx, None)]

    warp_idx = tidx // 32
    lane_idx = tidx % 32

    # Use shared helper for Q/K normalization (only warp 0 does the work)
    if warp_idx == 0:
        normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps)

    cute.arch.sync_threads()

    # Load V
    v_head = gV[(batch_idx, 0, value_head_idx, None)]
    v_sh = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh[tidx] = v_head[tidx].to(cutlass.Float32)

    # Registers: h_chunk + k_chunk (persistent) + qk_temp (reused for Q)
    h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32)
    k_chunk = cute.make_rmem_tensor((32,), cutlass.Float32)  # PERSISTENT K!
    qk_temp = cute.make_rmem_tensor((32,), cutlass.Float32)

    k_base = warp_idx * 32

    # Load K ONCE - keep for entire kernel
    for i in cutlass.range_constexpr(32):
        k_chunk[i] = k_sh[k_base + i]

    h_out = gH[(batch_idx, value_head_idx, None, None)]
    o_head = gO[(batch_idx, 0, value_head_idx, None)]

    # ========================================================================
    # CHUNK 0
    # ========================================================================
    nvvm.cp_async_wait_group(1)
    cute.arch.sync_threads()

    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk0[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk0[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred = pred + pred2

    pred_sh[warp_idx, lane_idx] = pred
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[0, lane_idx]
        + pred_sh[1, lane_idx]
        + pred_sh[2, lane_idx]
        + pred_sh[3, lane_idx]
    )

    v_val = (v_sh[lane_idx] - pred_final) * beta

    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )

    # Load Q for output computation
    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]

    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2

    out_sh[warp_idx, lane_idx] = out
    cute.arch.sync_threads()
    out_final = (
        out_sh[0, lane_idx]
        + out_sh[1, lane_idx]
        + out_sh[2, lane_idx]
        + out_sh[3, lane_idx]
    )

    write_h_chunk_to_smem(h_chunk, h_sh_chunk0, lane_idx, k_base)
    if warp_idx == 0:
        o_head[lane_idx] = out_final.to(cutlass.BFloat16)

    # ========================================================================
    # CHUNK 1
    # ========================================================================
    nvvm.cp_async_wait_group(0)
    cute.arch.sync_threads()

    load_h_chunk_async(h_sh_chunk2, h_global, tidx, 64)
    nvvm.cp_async_commit_group()
    load_h_chunk_async(h_sh_chunk3, h_global, tidx, 96)
    nvvm.cp_async_commit_group()

    store_h_smem_to_gmem(h_sh_chunk0, h_out, tidx, 0)

    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk1[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk1[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred = pred + pred2

    pred_sh[warp_idx, lane_idx] = pred
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[0, lane_idx]
        + pred_sh[1, lane_idx]
        + pred_sh[2, lane_idx]
        + pred_sh[3, lane_idx]
    )

    v_val = (v_sh[32 + lane_idx] - pred_final) * beta

    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )

    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]

    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2

    out_sh[warp_idx, lane_idx] = out
    cute.arch.sync_threads()
    out_final = (
        out_sh[0, lane_idx]
        + out_sh[1, lane_idx]
        + out_sh[2, lane_idx]
        + out_sh[3, lane_idx]
    )

    write_h_chunk_to_smem(h_chunk, h_sh_chunk1, lane_idx, k_base)
    if warp_idx == 0:
        o_head[32 + lane_idx] = out_final.to(cutlass.BFloat16)

    # ========================================================================
    # CHUNK 2
    # ========================================================================
    nvvm.cp_async_wait_group(1)
    cute.arch.sync_threads()

    store_h_smem_to_gmem(h_sh_chunk1, h_out, tidx, 32)

    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk2[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk2[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred = pred + pred2

    pred_sh[warp_idx, lane_idx] = pred
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[0, lane_idx]
        + pred_sh[1, lane_idx]
        + pred_sh[2, lane_idx]
        + pred_sh[3, lane_idx]
    )

    v_val = (v_sh[64 + lane_idx] - pred_final) * beta

    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )

    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]

    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2

    out_sh[warp_idx, lane_idx] = out
    cute.arch.sync_threads()
    out_final = (
        out_sh[0, lane_idx]
        + out_sh[1, lane_idx]
        + out_sh[2, lane_idx]
        + out_sh[3, lane_idx]
    )

    write_h_chunk_to_smem(h_chunk, h_sh_chunk2, lane_idx, k_base)
    if warp_idx == 0:
        o_head[64 + lane_idx] = out_final.to(cutlass.BFloat16)

    # ========================================================================
    # CHUNK 3
    # ========================================================================
    nvvm.cp_async_wait_group(0)
    cute.arch.sync_threads()

    store_h_smem_to_gmem(h_sh_chunk2, h_out, tidx, 64)

    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk3[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk3[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred = pred + pred2

    pred_sh[warp_idx, lane_idx] = pred
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[0, lane_idx]
        + pred_sh[1, lane_idx]
        + pred_sh[2, lane_idx]
        + pred_sh[3, lane_idx]
    )

    v_val = (v_sh[96 + lane_idx] - pred_final) * beta

    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )

    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]

    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2

    out_sh[warp_idx, lane_idx] = out
    cute.arch.sync_threads()
    out_final = (
        out_sh[0, lane_idx]
        + out_sh[1, lane_idx]
        + out_sh[2, lane_idx]
        + out_sh[3, lane_idx]
    )

    write_h_chunk_to_smem(h_chunk, h_sh_chunk3, lane_idx, k_base)
    if warp_idx == 0:
        o_head[96 + lane_idx] = out_final.to(cutlass.BFloat16)

    cute.arch.sync_threads()
    store_h_smem_to_gmem(h_sh_chunk3, h_out, tidx, 96)

Contributor


medium

The logic for processing each of the four V-dimension chunks is nearly identical, leading to significant code duplication within this kernel. This makes the code difficult to read and maintain, as any change in the chunk processing logic needs to be manually replicated four times.

Consider refactoring the duplicated computation into a @cute.jit helper function. This function could encapsulate the logic for decaying the state, computing predictions, updating the state, and calculating the output for a single chunk. The main kernel would then manage the pipelined loading and storing of data while calling this helper function for each chunk. This would greatly improve code clarity and maintainability.
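One possible shape for that refactor is sketched below; the helper name process_v_chunk is hypothetical, and the argument plumbing would need to match the existing @cute.jit helpers in this file (write_h_chunk_to_smem, the [4,32] reduction buffers, etc.):

```python
@cute.jit
def process_v_chunk(
    h_sh_chunk,   # SMEM tile holding this 32-column slice of H
    h_chunk,      # per-thread register tile, reused across chunks
    k_chunk,      # persistent K registers
    qk_temp,
    q_sh,
    v_sh,
    pred_sh,
    out_sh,
    o_head,
    g_exp,
    beta,
    v_offset,     # 0 / 32 / 64 / 96
    warp_idx,
    lane_idx,
    k_base,
):
    # Decay the state slice and compute the prediction of v from it.
    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred_sh[warp_idx, lane_idx] = pred + pred2
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[0, lane_idx] + pred_sh[1, lane_idx]
        + pred_sh[2, lane_idx] + pred_sh[3, lane_idx]
    )

    # Delta-rule update of the state slice, then this chunk's output contribution.
    v_val = (v_sh[v_offset + lane_idx] - pred_final) * beta
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )
    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]
    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out_sh[warp_idx, lane_idx] = out + out2
    cute.arch.sync_threads()
    out_final = (
        out_sh[0, lane_idx] + out_sh[1, lane_idx]
        + out_sh[2, lane_idx] + out_sh[3, lane_idx]
    )
    write_h_chunk_to_smem(h_chunk, h_sh_chunk, lane_idx, k_base)
    if warp_idx == 0:
        o_head[v_offset + lane_idx] = out_final.to(cutlass.BFloat16)
```

The main kernel body would then only interleave the cp.async issue/wait calls and store_h_smem_to_gmem around four calls to this helper.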

Comment on lines +1647 to +1679
class GatedDeltaRuleKernel:
    """
    Gated Delta Rule Kernel for linear attention decode.

    This kernel implements the Gated Delta Rule mechanism supporting sequence
    lengths T=1, T=2, T=3, T=4 with optimized CUDA implementations.

    Key features:
    - T=1: Persistent K in registers with aggressive pipelining
    - T=2/3/4: Unified kernel with compile-time Constexpr specialization
    - L2-normalized Q/K with configurable scale
    - Gated exponential decay via softplus
    - Bank-conflict-free cross-warp reductions
    - Async H memory loading

    Args:
        seq_len: Sequence length (1, 2, 3, or 4)
    """

    def __init__(self, seq_len: int):
        assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}"
        self.seq_len = seq_len
        self._compiled_kernel = None

    def _get_launch_fn(self):
        if self.seq_len == 1:
            return gated_delta_rule_launch_seqlen1
        elif self.seq_len == 2:
            return gated_delta_rule_launch_seqlen2
        elif self.seq_len == 3:
            return gated_delta_rule_launch_seqlen3
        else:
            return gated_delta_rule_launch_seqlen4
Contributor


medium

The GatedDeltaRuleKernel class is defined and exported as part of the public API, but it appears to be incomplete and is not used by the main gated_delta_rule function. The _compiled_kernel member is initialized to None and is never assigned a compiled kernel, and there is no method to execute it. The functional entry point gated_delta_rule implements its own caching and launch logic.

If this class is intended for future use, it should be fully implemented. If it is obsolete or a work-in-progress, it should either be completed or removed to avoid confusing users of the library.
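If the class is kept, a minimal way to make it usable is to delegate to the functional entry point (sketch only; keyword names are taken from the benchmark in this PR):

```python
class GatedDeltaRuleKernel:
    def __init__(self, seq_len: int):
        assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}"
        self.seq_len = seq_len

    def __call__(self, *, A_log, a, dt_bias, q, k, v, b, initial_state_source, scale):
        # Reuse the module-level kernel cache and dispatch instead of duplicating it here.
        assert q.shape[1] == self.seq_len, "q's T dimension must match seq_len"
        return gated_delta_rule(
            A_log=A_log, a=a, dt_bias=dt_bias,
            q=q, k=k, v=v, b=b,
            initial_state_source=initial_state_source,
            scale=scale,
        )
```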

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 1773-1774: The cache_key currently set to (T, B) is too narrow and
can cause incorrect kernel reuse; update the cache key creation (variable
cache_key) to include the relevant tensor dimensions such as H, HV, K, and V
(the head count and per-head dims) so it uniquely identifies tensor shapes used
by the kernel—derive these from the involved tensors (the q/k/v/proj shapes or
whatever variables represent heads and head sizes) and include them alongside T
and B when building cache_key to prevent shape-mismatch cache hits.
- Around line 1666-1680: GatedDeltaRuleKernel currently only stores seq_len and
_compiled_kernel and provides _get_launch_fn but is never used by
gated_delta_rule or _compiled_kernels; either finish the class by adding a
public execution API (e.g., an __call__ or execute method that compiles/looks up
the kernel into _compiled_kernel, uses _get_launch_fn to obtain the launch
function and runs it with the same signature as the module-level implementation)
or remove the class from the public surface and migrate any kernel-caching logic
into the existing module-level _compiled_kernels paths; reference the class name
GatedDeltaRuleKernel, its attribute _compiled_kernel, method _get_launch_fn, and
the public gated_delta_rule/_compiled_kernels cache when making the change.
- Around line 1746-1749: The code accesses q, k, v, b and initial_state_source
without null checks even though their annotations allow None; add explicit
validation at the start of the function (e.g., in the function containing the
lines that reference q.shape and v.shape) to raise a clear ValueError if any of
q, k, v, b, or initial_state_source is None (or alternatively update the
function signature to remove Optional typing for these parameters), and ensure
any downstream use of from_dlpack(initial_state_source, ...) only occurs after
confirming initial_state_source is not None; reference the variables q, k, v, b
and initial_state_source and the shapes accessed (q.shape, v.shape[2],
v.shape[3]) so the checks are placed before those accesses.
- Around line 1153-1165: Unconditional shared-memory allocations q_sh2, k_sh2,
q_sh3, k_sh3 and the extra v_sh* buffers are always created even when NUM_TOKENS
is smaller; wrap those smem.allocate_tensor(...) calls in compile-time guards so
unused buffers are eliminated. Specifically, change the allocations for
q_sh2/k_sh2 to be inside an if constexpr (NUM_TOKENS >= 3) block and the
allocations for q_sh3/k_sh3 to be inside an if constexpr (NUM_TOKENS == 4) block
(and similarly guard v_sh2/v_sh3 as appropriate), keeping the same
smem.allocate_tensor(cutlass.Float32, 128) calls and names so later code (gate
computations) still references the same identifiers when compiled in. Ensure the
guards use the compile-time NUM_TOKENS parameter so the compiler can drop unused
allocations.
🧹 Nitpick comments (5)
flashinfer/cute_dsl/gated_delta_rule.py (3)

118-136: Potential numerical stability concern in compute_single_gate.

The softplus implementation handles large positive values but not large negative values of beta_x. When beta_x is very negative, cute.math.exp(beta_x) approaches zero, which is fine numerically. However, the sigmoid computation at line 134 could have issues with large positive beta_raw values where cute.math.exp(-beta_raw) underflows to 0, resulting in beta = 1.0 (acceptable), but large negative beta_raw causes exp(-beta_raw) to overflow.

Consider adding a symmetric threshold check for the sigmoid similar to the softplus handling.

💡 Optional: Add numerical safeguard for sigmoid
     g = -cute.math.exp(A_log_val) * softplus_x
     g_exp = cute.math.exp(g)
-    beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw))
+    # Numerically stable sigmoid
+    if beta_raw >= cutlass.Float32(0.0):
+        beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw))
+    else:
+        exp_beta = cute.math.exp(beta_raw)
+        beta = exp_beta / (cutlass.Float32(1.0) + exp_beta)
     return g_exp, beta

306-354: Consider removing unused o_head parameter.

The o_head parameter is passed but never used in process_first_token. The first token's output is returned and stored later during subsequent token processing. Removing it would clarify the function's contract.

♻️ Remove unused parameter
 `@cute.jit`
 def process_first_token(
     h_sh_chunk_curr,
     h_chunk,
     kq_chunk,
     k_sh,
     q_sh,
     v_sh,
     reduce_sh,
-    o_head,
     g_exp,
     beta,
     v_offset,
     pred_slot,
     warp_idx,
     lane_idx,
     k_base,
 ):

Note: This would require updating all call sites in process_vchunk_unified_234.


1746-1746: Unpacked variable H is unused.

The variable H is extracted from q.shape but never used in the function. Consider prefixing with underscore to indicate intentional non-use.

♻️ Prefix unused variable
-    B, T, H, K = q.shape
+    B, T, _H, K = q.shape
flashinfer/cute_dsl/benchmark_gated_delta_rule.py (2)

82-82: Import path may fail depending on execution context.

The relative import from gated_delta_rule import gated_delta_rule assumes the script is run from within the flashinfer/cute_dsl/ directory or that directory is in sys.path. This will fail if run from the repository root or via python -m.

Consider using absolute import for robustness.

♻️ Use absolute import
 def main():
-    from gated_delta_rule import gated_delta_rule
+    from flashinfer.cute_dsl.gated_delta_rule import gated_delta_rule

116-127: Loop variable capture in closure - latent bug risk.

The run_kernel function captures inputs and state by reference, not by value. While this works correctly here because the closure is consumed immediately via benchmark(), it's a latent bug that would manifest if the code were refactored (e.g., collecting closures to run later).

♻️ Bind loop variables explicitly
-            def run_kernel():
+            def run_kernel(inputs=inputs, state=state):
                 return gated_delta_rule(
                     A_log=inputs["A_log"],
                     a=inputs["a"],
                     dt_bias=inputs["dt_bias"],
                     q=inputs["q"],
                     k=inputs["k"],
                     v=inputs["v"],
                     b=inputs["b"],
                     initial_state_source=state,
                     scale=inputs["scale"],
                 )

Comment on lines +1153 to +1165
    # Q/K buffers for token 2 (only for NUM_TOKENS >= 3)
    q_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh2 = smem.allocate_tensor(cutlass.Float32, 128)

    # Q/K buffers for token 3 (only for NUM_TOKENS=4)
    q_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh3 = smem.allocate_tensor(cutlass.Float32, 128)

    # V buffers
    v_sh0 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh1 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cd /tmp && git ls-files flashinfer/cute_dsl/gated_delta_rule.py

Repository: flashinfer-ai/flashinfer

Length of output: 138


🏁 Script executed:

wc -l flashinfer/cute_dsl/gated_delta_rule.py

Repository: flashinfer-ai/flashinfer

Length of output: 112


🏁 Script executed:

sed -n '1140,1180p' flashinfer/cute_dsl/gated_delta_rule.py

Repository: flashinfer-ai/flashinfer

Length of output: 1718


🏁 Script executed:

# Search for NUM_TOKENS usage and conditional patterns around this section
sed -n '1100,1200p' flashinfer/cute_dsl/gated_delta_rule.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 3977


🏁 Script executed:

# Look for if statements guarding buffer allocation
rg -n "if.*NUM_TOKENS" flashinfer/cute_dsl/gated_delta_rule.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 357


🌐 Web query:

CuTe DSL Constexpr shared memory allocation dead code elimination

💡 Result:

In CuTe / CuTeDSL (CUTLASS), whether “shared-memory allocation” and the associated copies/loops disappear is mostly a question of whether the compiler can prove things at compile time (i.e., via types and constexpr / if constexpr) and whether the backend can safely apply dead-store / dead-code elimination across the generated code.

1) “constexpr shared memory allocation” in practice (CuTe/CUTLASS)

  • CuTe is designed so that tile shapes, layouts, and smem footprints are carried in the type system; when those are compile-time constants, the compiler can fully specialize code (unroll, constant-fold addresses, remove unused objects/paths).
  • Practically, that means: prefer CuTe’s compile-time control-flow (if constexpr, cute::static_for, type-based shapes/layouts) rather than runtime loops/conditions if you expect code (and sometimes even smem objects) to vanish.
  • CUTLASS’ own guidelines explicitly lean on if constexpr-based exhaustiveness and “unreachable” patterns to ensure unwanted branches don’t survive compilation (and to avoid “missing return” / unreachable-path issues) [1].

2) Why you can see “dead-looking” stores (local or shared) even at -O3

What you’re describing matches a real pattern people have observed with CuTeDSL-generated kernels: stores to .local (stack) that appear unused even though the values are consumed directly from registers in nearby FFMA instructions.

A recent CuTeDSL example shows exactly this: STL.128 stores to local memory immediately followed by FFMA using the same registers, with no obvious reloads in between [2]. A plausible cause suggested in-thread is conservative aliasing / overlap concerns (compiler can’t prove the store is irrelevant because it can’t fully prove what may alias what) [2].

Also, NVIDIA engineers have long emphasized an important split:

  • PTX is intermediate; optimizations (including dead-code elimination) may happen later in ptxas, and you should judge final behavior by SASS [3].
  • However, in your case you’re already looking at SASS and still seeing the stores, which usually means one of:
    • the backend is being conservative (aliasing, lifetime issues),
    • the store participates in some internal lowering pattern (e.g., preserving values across a region/edge the optimizer doesn’t fully reason about),
    • or it’s simply a backend missed-optimization / bug (these have been reported historically for redundant local loads/stores) [4].

3) Concrete knobs that usually decide whether it optimizes away

  • Make indices compile-time: if a loop induction variable controls fragment selection/addressing, switching to cute::static_for / unrolling often removes unexpected local traffic (your own observation aligns with this).
  • Prefer if constexpr for optional smem paths so unused code paths don’t exist in the instantiated kernel (CUTLASS recommends these patterns) [1].
  • Avoid debug/per-function compilation modes when judging DCE (e.g., -G, device debug) because they commonly inhibit optimization (this is a common explanation for “dead” locals in PTX/SASS discussions) [3].
  • If aliasing is plausible, try enforcing non-aliasing where applicable (CUDA C++: __restrict__ on raw pointers; with CuTe views it’s trickier, but the same principle applies). The forum reply explicitly raises aliasing as a reason the compiler may keep the stores [2].

If you want, paste a minimal CuTeDSL repro (kernel + compile flags + GPU arch + CUDA version) and I’ll point out exactly which parts prevent compile-time proving (and what to change so smem/lmem traffic is actually eliminated).

Sources: CUTLASS programming guidelines [1]; NVIDIA forum thread showing CuTeDSL dead-looking STL.128 in SASS and aliasing hypothesis [2]; NVIDIA forum explanation that PTX is intermediate and DCE happens in ptxas / check SASS and debug builds inhibit this [3]; historical report of redundant lmem loads/stores as compiler/backend artifact/bug [4].


Unconditional SMEM allocation for unused token buffers wastes memory for smaller batch sizes.

The Q/K/V shared memory buffers are allocated unconditionally at compile time for all four tokens, even when NUM_TOKENS=2 or 3. Comments indicate buffers are "only for NUM_TOKENS >= 3" and "only for NUM_TOKENS=4", but no compile-time conditionals guard these allocations. The subsequent gate computations at lines 1195+ use if NUM_TOKENS runtime checks, which cannot eliminate the earlier allocations. For NUM_TOKENS=2, this means allocating q_sh2, k_sh2, q_sh3, k_sh3, and partial v_sh buffers unnecessarily.

While this trade-off is acceptable for a unified kernel (avoiding code duplication and supporting multiple sequence lengths), the SMEM footprint could be reduced by wrapping allocations in if constexpr guards tied to the NUM_TOKENS parameter, allowing true compile-time elimination rather than runtime conditionals.

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/gated_delta_rule.py` around lines 1153 - 1165,
Unconditional shared-memory allocations q_sh2, k_sh2, q_sh3, k_sh3 and the extra
v_sh* buffers are always created even when NUM_TOKENS is smaller; wrap those
smem.allocate_tensor(...) calls in compile-time guards so unused buffers are
eliminated. Specifically, change the allocations for q_sh2/k_sh2 to be inside an
if constexpr (NUM_TOKENS >= 3) block and the allocations for q_sh3/k_sh3 to be
inside an if constexpr (NUM_TOKENS == 4) block (and similarly guard v_sh2/v_sh3
as appropriate), keeping the same smem.allocate_tensor(cutlass.Float32, 128)
calls and names so later code (gate computations) still references the same
identifiers when compiled in. Ensure the guards use the compile-time NUM_TOKENS
parameter so the compiler can drop unused allocations.
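A sketch of the guarded allocation, assuming NUM_TOKENS is a compile-time Constexpr parameter and that the DSL's cutlass.const_expr compile-time branch is usable here (buffer names match the quoted code; this is illustrative, not the shipped implementation):

```python
    # Allocate the extra per-token Q/K/V buffers only when the compile-time
    # token count actually needs them.
    q_sh2 = None
    k_sh2 = None
    if cutlass.const_expr(NUM_TOKENS >= 3):
        q_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
        k_sh2 = smem.allocate_tensor(cutlass.Float32, 128)

    q_sh3 = None
    k_sh3 = None
    if cutlass.const_expr(NUM_TOKENS == 4):
        q_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
        k_sh3 = smem.allocate_tensor(cutlass.Float32, 128)

    # V buffers: tokens 0/1 always, 2 and 3 only when present.
    v_sh0 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh1 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh2 = None
    if cutlass.const_expr(NUM_TOKENS >= 3):
        v_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh3 = None
    if cutlass.const_expr(NUM_TOKENS == 4):
        v_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
```

The downstream gate computations would then need the same compile-time guards so the unused names are never referenced.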

Comment on lines +1666 to +1680
    def __init__(self, seq_len: int):
        assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}"
        self.seq_len = seq_len
        self._compiled_kernel = None

    def _get_launch_fn(self):
        if self.seq_len == 1:
            return gated_delta_rule_launch_seqlen1
        elif self.seq_len == 2:
            return gated_delta_rule_launch_seqlen2
        elif self.seq_len == 3:
            return gated_delta_rule_launch_seqlen3
        else:
            return gated_delta_rule_launch_seqlen4

Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Check if GatedDeltaRuleKernel is used anywhere in the codebase
rg -n 'GatedDeltaRuleKernel' --type=py -A 2 -B 2

Repository: flashinfer-ai/flashinfer

Length of output: 900


🏁 Script executed:

# Read the full GatedDeltaRuleKernel class implementation
cat -n flashinfer/cute_dsl/gated_delta_rule.py | sed -n '1645,1750p'

Repository: flashinfer-ai/flashinfer

Length of output: 4690


🏁 Script executed:

# Check for the gated_delta_rule function and its implementation
rg -n 'def gated_delta_rule' --type=py -A 30

Repository: flashinfer-ai/flashinfer

Length of output: 21951


🏁 Script executed:

# Check if gated_delta_rule function uses GatedDeltaRuleKernel internally
sed -n '1689,1800p' flashinfer/cute_dsl/gated_delta_rule.py | grep -i 'GatedDeltaRuleKernel\|_compiled_kernel'

Repository: flashinfer-ai/flashinfer

Length of output: 192


🏁 Script executed:

# Search for any usage of _compiled_kernel attribute across the file
rg -n '_compiled_kernel' flashinfer/cute_dsl/gated_delta_rule.py

Repository: flashinfer-ai/flashinfer

Length of output: 366


GatedDeltaRuleKernel class is incomplete and unused.

The class initializes _compiled_kernel = None but never uses it, has no execution method (no __call__, execute, or similar), and is not instantiated by the public gated_delta_rule function. The module uses a module-level _compiled_kernels cache instead. Consider either completing this class with an execution interface or removing it from the public API if it's not intended for direct use.

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/gated_delta_rule.py` around lines 1666 - 1680,
GatedDeltaRuleKernel currently only stores seq_len and _compiled_kernel and
provides _get_launch_fn but is never used by gated_delta_rule or
_compiled_kernels; either finish the class by adding a public execution API
(e.g., an __call__ or execute method that compiles/looks up the kernel into
_compiled_kernel, uses _get_launch_fn to obtain the launch function and runs it
with the same signature as the module-level implementation) or remove the class
from the public surface and migrate any kernel-caching logic into the existing
module-level _compiled_kernels paths; reference the class name
GatedDeltaRuleKernel, its attribute _compiled_kernel, method _get_launch_fn, and
the public gated_delta_rule/_compiled_kernels cache when making the change.

Comment on lines +1746 to +1749
    B, T, H, K = q.shape
    assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}"
    HV = v.shape[2]
    V = v.shape[3]
Contributor


⚠️ Potential issue | 🟡 Minor

Missing input validation for required parameters.

The function signature declares q, k, v, b, and initial_state_source as Optional, but the code unconditionally accesses them (e.g., q.shape, v.shape[2], from_dlpack(initial_state_source, ...)). If any are None, this will raise unhelpful errors.

Either validate these inputs explicitly or change the type hints to reflect they're actually required.

🛡️ Proposed fix: Add validation
     global _compiled_kernels
 
+    if q is None or k is None or v is None or b is None or initial_state_source is None:
+        raise ValueError("q, k, v, b, and initial_state_source are required parameters")
+
     B, T, H, K = q.shape
     assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}"

Or change signatures to non-Optional:

-    q: Optional[torch.Tensor] = None,
-    k: Optional[torch.Tensor] = None,
-    v: Optional[torch.Tensor] = None,
-    b: Optional[torch.Tensor] = None,
-    initial_state_source: Optional[torch.Tensor] = None,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    b: torch.Tensor,
+    initial_state_source: torch.Tensor,
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 1746-1746: Unpacked variable H is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/gated_delta_rule.py` around lines 1746 - 1749, The code
accesses q, k, v, b and initial_state_source without null checks even though
their annotations allow None; add explicit validation at the start of the
function (e.g., in the function containing the lines that reference q.shape and
v.shape) to raise a clear ValueError if any of q, k, v, b, or
initial_state_source is None (or alternatively update the function signature to
remove Optional typing for these parameters), and ensure any downstream use of
from_dlpack(initial_state_source, ...) only occurs after confirming
initial_state_source is not None; reference the variables q, k, v, b and
initial_state_source and the shapes accessed (q.shape, v.shape[2], v.shape[3])
so the checks are placed before those accesses.

Comment on lines +1773 to +1774
    # Check cache
    cache_key = (T, B)
Contributor


⚠️ Potential issue | 🟠 Major

Cache key may cause unnecessary recompilations or cache misses.

The cache key (T, B) doesn't account for tensor shapes beyond batch size and sequence length. If H, HV, K, or V change between calls with the same (T, B), the cached kernel might be invoked with incompatible tensor shapes, potentially causing silent correctness issues or crashes.

Consider including relevant shape dimensions in the cache key.

🐛 Proposed fix
     # Check cache
-    cache_key = (T, B)
+    cache_key = (T, B, H, HV, K, V)
     if cache_key not in _compiled_kernels:
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/gated_delta_rule.py` around lines 1773 - 1774, The
cache_key currently set to (T, B) is too narrow and can cause incorrect kernel
reuse; update the cache key creation (variable cache_key) to include the
relevant tensor dimensions such as H, HV, K, and V (the head count and per-head
dims) so it uniquely identifies tensor shapes used by the kernel—derive these
from the involved tensors (the q/k/v/proj shapes or whatever variables represent
heads and head sizes) and include them alongside T and B when building cache_key
to prevent shape-mismatch cache hits.

@yzh119
Collaborator

yzh119 commented Feb 5, 2026

@ameynaik-hub would you mind comparing with https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py ? It's better to unify the file location and interface.

@ameynaik-hub
Author

@yzh119 compared benchmark here #2493 (comment)

The https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py kernel supports both the K-major and V-major H state layouts. Also, I believe it supports seqlen > 4.

On the other hand, this kernel only supports the V-major (K-contiguous, K-last dim) H state layout, which is used for Qwen models, and fixed seqlen <= 4 only.

Any suggestions on how to merge?

@vadiklyutiy

@ameynaik-hub would you mind comparing with https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py ? It's better to unify the file location and interface.

and test as well

@vadiklyutiy

@yzh119 compared benchmark here #2493 (comment)

I think it's worth copying the results here.

@vadiklyutiy

The new kernel is always better than the old one with pretranspose, except for batch=1, which we can excuse. Does it make sense to keep the old pretranspose variant?

@aditya-narayan5

aditya-narayan5 commented Feb 5, 2026

@ameynaik-hub I might have overlooked details while reading, but does anyone have a source for this:

V-major (K-contiguous, K-last dim) H state layout, which is used for Qwen models

@ameynaik-hub
Author

ameynaik-hub commented Feb 5, 2026

The new kernel is always better than the old one with pretranspose, except for batch=1, which we can excuse. Does it make sense to keep the old pretranspose variant?

@vadiklyutiy since it supports T > 4 (seqlen) and also supports the V-last dim of H, I thought it could be kept as is?


@vadiklyutiy

The new kernel is always better than the old one with pretranspose, except for batch=1, which we can excuse. Does it make sense to keep the old pretranspose variant?

@vadiklyutiy since it supports T > 4 (seqlen) and also supports the V-last dim of H, I thought it could be kept as is?

As I remember, there were several different interfaces/functions. I meant the one that processes the seq_len=1, K-last case.
Of course I have not tried to propose a unified interface; I just want to avoid duplication.

@yzh119
Collaborator

yzh119 commented Feb 5, 2026

@ameynaik-hub can you reuse the gdn decode interface and set your implementation as the default when the K-last (V-major) layout is used and T <= 4?
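For concreteness, the requested routing could look roughly like the sketch below; gdn_decode_dispatch, existing_gdn_decode_impl, and the state_layout flag are placeholders for whatever flashinfer/gdn_decode.py actually exposes:

```python
def gdn_decode_dispatch(q, k, v, b, a, A_log, dt_bias, state, scale, state_layout):
    """Hypothetical routing layer: prefer the CuTe-DSL kernel when it applies."""
    T = q.shape[1]
    k_last_state = state_layout == "k_last"  # V-major, K contiguous: [B, HV, V, K]
    if k_last_state and T <= 4:
        return gated_delta_rule(
            A_log=A_log, a=a, dt_bias=dt_bias,
            q=q, k=k, v=v, b=b,
            initial_state_source=state, scale=scale,
        )
    # Fall back to the existing implementation for other layouts or longer T.
    return existing_gdn_decode_impl(q, k, v, b, a, A_log, dt_bias, state, scale)
```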

Collaborator

@kahyunnam kahyunnam left a comment


@ameynaik-hub this LGTM as the initial PR to just get the kernel into the Flashinfer codebase, approved. (still need a codeowner approval for this directory to merge ... maybe @yzh119 or @bkryu ). We can follow up with another PR for some work for the API integration + Flashinfer testing/benchmarking. Let me know if you have the bandwidth to work on this integration; if not, I can probably work on it next week.

@yzh119
Collaborator

yzh119 commented Feb 5, 2026

Hi @kahyunnam, I encourage not pushing more code into flashinfer/cute_dsl anymore, as we plan to categorize modules by functionality, not source. Other PRs are already starting to move code out of flashinfer.cute_dsl.

The work required to unify the existing gdn decode and this PR should be minimal; I can help if you need any assistance with it.



def benchmark(
    func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True
Collaborator


I'm just curious why that wasn't addressed and pointed out in #2370 but was here?

Collaborator


The benchmarking script in #2370 should use bench_gpu_time as well. I didn't notice that when reviewing that PR (it's my bad) and it was fixed in #2405.

Let's get it right in one shot in this PR.

        else:  # T == 4
            launch_fn = gated_delta_rule_launch_seqlen4

        _compiled_kernels[cache_key] = cute.compile(
Collaborator


Please enable tvm-ffi with options="--enable-tvm-ffi", otherwise the host-side overhead of converting torch.Tensor to DLPack would be non-negligible.

Reference can be found at #2279
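For reference, the requested change is just threading the extra option through the existing compile call, as the reviewer suggests (sketch; the other arguments are elided):

```python
        _compiled_kernels[cache_key] = cute.compile(
            launch_fn,
            # ... the same tensor / stream arguments as today ...
            options="--enable-tvm-ffi",  # avoids per-call torch.Tensor -> DLPack overhead
        )
```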

Author


Thanks will do.

@ameynaik-hub
Author

@ameynaik-hub can you reuse the gdn decode interface and set your implementation as the default when the K-last (V-major) layout is used and T <= 4?

@yzh119 yeah I can do that.

I noticed that my kernel utilizes bf16 inputs for the state h, which appears to be compatible with sglang and other models. However, it seems that the default state for h is fp32. I am investigating the possibility of optimizing this configuration for this specific case.

cc: @vadiklyutiy @kahyunnam

@yzh119
Collaborator

yzh119 commented Feb 6, 2026

cc @guangyunh-nv (gdn prefill's author) do you have any insights on whether it's okay to use f16 state?

@guangyunh-nv
Collaborator

An FP16 state might not be a good choice due to its dynamic range limitation. The range is not guaranteed and depends purely on the model's activation dynamics, so this really comes down to model behavior.

Datapoint:
Previously, with MiniMax M1 (lightning attention, similar to GDN but with gating that is not data dependent), the model only ran with BF16 because its dynamic range somehow grows exponentially and then saturates.

BF16 may be safer, but accuracy would then be the concern. If the algorithm side decides this is tolerable, then I think we should support it.

I think what we can do at the moment is to tolerate all kinds of state input, and learn from the production environment feedback :)
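A quick, self-contained way to get a feel for the storage-dtype question is a toy drift experiment like the one below (illustrative only; it says nothing about any particular model's activation dynamics):

```python
import torch

torch.manual_seed(0)
V, K, steps = 128, 128, 4096
h_fp32 = torch.zeros(V, K)
h_bf16 = torch.zeros(V, K, dtype=torch.bfloat16)

for _ in range(steps):
    k = torch.nn.functional.normalize(torch.randn(K), dim=-1)
    v = torch.randn(V)
    g, beta = 0.99, 0.5  # fixed gate/beta just for this toy run

    # FP32-stored state
    h_fp32 = g * h_fp32 + beta * torch.outer(v - h_fp32 @ k, k)
    # BF16-stored state with FP32 compute (mirrors the kernel's storage choice)
    h32 = h_bf16.float()
    h_bf16 = (g * h32 + beta * torch.outer(v - h32 @ k, k)).to(torch.bfloat16)

rel_err = (h_bf16.float() - h_fp32).norm() / h_fp32.norm()
print(f"relative state error after {steps} steps: {rel_err:.3e}")
```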

@xutizhou
Contributor

xutizhou commented Feb 6, 2026

Have you tested the end-to-end accuracy with FP16 SSM state? Does the performance remain unchanged on common benchmarks such as MMLU, GSM8K, etc.?

@vadiklyutiy

I think what we can do at the moment is to tolerate all kinds of state input, and learn from the production environment feedback :)

I don't think this is the right approach. Imagine a car manufacturer introducing a new self-driving car and saying "let's learn from the road environment." I don't think a lot of people would be happy.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 1695-1709: The parameter use_qk_l2norm_in_kernel on
gated_delta_rule is currently ignored because the kernel always normalizes Q/K;
add an explicit guard at the top of gated_delta_rule to either (a) implement a
non-normalizing code path for Q/K or (b) fail fast — e.g., if not
use_qk_l2norm_in_kernel: raise ValueError("use_qk_l2norm_in_kernel=False is not
supported: kernel currently always L2-normalizes q/k"), so callers get explicit
behavior instead of silent incorrect semantics; reference the gated_delta_rule
function and the use_qk_l2norm_in_kernel parameter when adding the check.
- Around line 1695-1765: The gated_delta_rule kernel assumes specific dtypes and
shape invariants (BF16 for q/k/v/a/b and float32 for A_log/dt_bias, K==128, HV
divisible by H, initial_state_source shape [B, HV, V, K] and matching K/V/H/B
dimensions) but doesn't validate them; add upfront assertions in
gated_delta_rule to check q/k/v/a/b dtypes, A_log/dt_bias dtypes, that
q.shape==(B,T,H,K) and v.shape==(B,T,HV,V), that K == 128 (or require K ==
k.shape[-1] and optionally enforce 128), that HV % H == 0, and that
initial_state_source (if provided) has dtype and shape [B, HV, V, K] and
matching B, HV, V, K; raise clear ValueError/AssertionError with messages
referencing the symbols (q, k, v, a, b, A_log, dt_bias, initial_state_source,
HV, K, H) so callers fail fast instead of producing silent corruption.
🧹 Nitpick comments (4)
flashinfer/cute_dsl/gated_delta_rule.py (1)

311-360: Remove unused o_head parameter to avoid lint noise.

It is never referenced inside process_first_token.

♻️ Suggested fix
-def process_first_token(
+def process_first_token(
     h_sh_chunk_curr,
     h_chunk,
     kq_chunk,
     k_sh,
     q_sh,
     v_sh,
     reduce_sh,
-    o_head,
+    _o_head,
     g_exp,
     beta,
     v_offset,
     pred_slot,
     warp_idx,
     lane_idx,
     k_base,
 ):
tests/gdn/test_decode_delta_rule.py (1)

622-653: Mark unused alpha as intentionally unused to silence lint.

This keeps the signature but avoids ARG001 noise.

♻️ Suggested fix
-def _test_improved_cutedsl_kernel(
+def _test_improved_cutedsl_kernel(
     dtype: str,
     batch_size: int,
     num_q_heads: int,
     num_k_heads: int,
     num_v_heads: int,
     head_size: int,
     seq_len: int,  # T=1,2,3,4
     scale: float,
-    alpha: bool,
+    _alpha: bool,
     beta: bool,
     seed: int | None = None,
 ):
benchmarks/bench_gdn_decode.py (2)

1835-1875: Avoid unused output parameter in wrapper.

Rename to _output to silence ARG001 without changing behavior.

♻️ Suggested fix
 def improved_cutedsl_gdn_wrapper(
@@
-    output: torch.Tensor,  # [B, T, HV, V] - unused, kernel returns output directly
+    _output: torch.Tensor,  # [B, T, HV, V] - unused, kernel returns output directly
     use_qk_l2norm: bool = True,
     softplus_beta: float = 1.0,
     softplus_threshold: float = 20.0,
 ):

2213-2292: Use float32 dt_bias in improved CuTe‑DSL benchmark.

The improved kernel path (and tests) treat dt_bias as float32; aligning the benchmark avoids unintended precision differences.

♻️ Suggested fix
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")

Comment on lines +1695 to +1709
def gated_delta_rule(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float = 1.0,
    softplus_threshold: float = 20.0,
    q: Optional[torch.Tensor] = None,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    b: Optional[torch.Tensor] = None,
    initial_state_source: Optional[torch.Tensor] = None,
    initial_state_indices: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = True,
    scale: Optional[float] = None,
) -> torch.Tensor:
Contributor

⚠️ Potential issue | 🟠 Major

use_qk_l2norm_in_kernel is ignored — make behavior explicit.

The kernel always normalizes Q/K, so passing False silently produces incorrect semantics.

🛡️ Suggested guard
     A_log: torch.Tensor,
     a: torch.Tensor,
     dt_bias: torch.Tensor,
     softplus_beta: float = 1.0,
     softplus_threshold: float = 20.0,
@@
     use_qk_l2norm_in_kernel: bool = True,
     scale: Optional[float] = None,
 ) -> torch.Tensor:
@@
     global _compiled_kernels
+
+    if not use_qk_l2norm_in_kernel:
+        raise NotImplementedError(
+            "CuTe-DSL GDN kernel currently always applies Q/K L2 normalization"
+        )
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 1706-1706: Unused function argument: initial_state_indices

(ARG001)


[warning] 1707-1707: Unused function argument: use_qk_l2norm_in_kernel

(ARG001)

Comment on lines 1695 to 1765
def gated_delta_rule(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float = 1.0,
    softplus_threshold: float = 20.0,
    q: Optional[torch.Tensor] = None,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    b: Optional[torch.Tensor] = None,
    initial_state_source: Optional[torch.Tensor] = None,
    initial_state_indices: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = True,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """
    Gated Delta Rule linear attention kernel.

    Implements the Gated Delta Rule mechanism for decode-phase inference,
    supporting sequence lengths T=1, T=2, T=3, T=4.

    Args:
        A_log: Log decay parameter [HV]
        a: Alpha gate input [B, T, HV]
        dt_bias: Delta-t bias [HV]
        softplus_beta: Softplus beta parameter (default: 1.0)
        softplus_threshold: Softplus threshold (default: 20.0)
        q: Query tensor [B, T, H, K]
        k: Key tensor [B, T, H, K]
        v: Value tensor [B, T, HV, V]
        b: Beta gate input [B, T, HV]
        initial_state_source: H state [B, HV, V, K] (K-fast layout), modified in-place
        initial_state_indices: Not used (for compatibility)
        use_qk_l2norm_in_kernel: Whether to L2-normalize Q/K in kernel (default: True)
        scale: Optional attention scale (default: 1/sqrt(K))

    Returns:
        output: [B, T, HV, V]

    Example:
        >>> B, T, H, K = 16, 1, 16, 128
        >>> HV, V = 32, 128
        >>> q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
        >>> k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
        >>> v = torch.randn(B, T, HV, V, device='cuda', dtype=torch.bfloat16)
        >>> a = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16)
        >>> b = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16)
        >>> A_log = torch.randn(HV, device='cuda', dtype=torch.float32)
        >>> dt_bias = torch.randn(HV, device='cuda', dtype=torch.float32)
        >>> h_state = torch.randn(B, HV, V, K, device='cuda', dtype=torch.bfloat16)
        >>> output = gated_delta_rule(
        ...     A_log, a, dt_bias, q=q, k=k, v=v, b=b,
        ...     initial_state_source=h_state
        ... )
    """
    global _compiled_kernels

    B, T, H, K = q.shape
    assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}"
    HV = v.shape[2]
    V = v.shape[3]

    if scale is None:
        scale = 1.0 / math.sqrt(K)

    output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)

    q_ = from_dlpack(q, assumed_align=32)
    k_ = from_dlpack(k, assumed_align=32)
    v_ = from_dlpack(v, assumed_align=32)
    a_ = from_dlpack(a, assumed_align=32)
Contributor

⚠️ Potential issue | 🟠 Major

Add explicit dtype/shape validation for kernel invariants.

The kernel stores outputs and state as BF16 and assumes K/V=128 plus HV divisible by H. Without guards, callers can silently get corrupted outputs.

🛡️ Suggested validation
     global _compiled_kernels
 
     B, T, H, K = q.shape
     assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}"
     HV = v.shape[2]
     V = v.shape[3]
+
+    if K != 128 or V != 128:
+        raise ValueError(f"CuTe-DSL GDN kernel expects K=V=128, got K={K}, V={V}")
+    if HV % H != 0:
+        raise ValueError(f"HV must be divisible by H (HV={HV}, H={H})")
+    if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16 or v.dtype != torch.bfloat16:
+        raise ValueError("CuTe-DSL GDN kernel expects q/k/v in torch.bfloat16")
+    if initial_state_source.dtype != torch.bfloat16:
+        raise ValueError("CuTe-DSL GDN kernel expects state in torch.bfloat16")
+    if A_log.dtype != torch.float32 or dt_bias.dtype != torch.float32:
+        raise ValueError("CuTe-DSL GDN kernel expects A_log/dt_bias in torch.float32")
+    if initial_state_source.shape != (B, HV, V, K):
+        raise ValueError(
+            f"Expected state shape [B, HV, V, K] = ({B}, {HV}, {V}, {K}), "
+            f"got {tuple(initial_state_source.shape)}"
+        )
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 1706-1706: Unused function argument: initial_state_indices

(ARG001)


[warning] 1707-1707: Unused function argument: use_qk_l2norm_in_kernel

(ARG001)


[warning] 1752-1752: Unpacked variable H is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

@aditya-narayan5

@ameynaik-hub Thanks for linking to the PR by @yzh119:

fix the docstring of the GDN prefill kernel; instead of N,H,K,V, it expects N,H,V,K

Could the maintainers clarify why the N,H,V,K layout is preferred?

Most GDN reference implementations (e.g., the FLA project) expect the recurrent state in N,H,K,V, and the reference used in the tests in PR #2276 follows the same convention; we explicitly transpose the kernel output from N,H,V,K for comparison against the reference.

PR #2370 already provides both a pre‑transpose (N,H,V,K) and a non‑transpose (N,H,K,V) path, and the inline comments recommend the non‑transpose version. I’ve searched the Qwen3 documentation and the official Hugging Face Transformers implementation (which uses FLA under the hood) and haven’t found any explicit requirement for the N,H,V,K layout.

Is there a downstream consumer or hardware target that mandates this layout?

@vadiklyutiy

@ameynaik-hub Thanks for linking to the PR by @yzh119:

fix the docstring of the GDN prefill kernel; instead of N,H,K,V, it expects N,H,V,K

Could the maintainers clarify why the N,H,V,K layout is preferred?

Most GDN reference implementations (e.g., the FLA project) expect the recurrent state in N,H,K,V, and the reference used in the tests in PR #2276 follows the same convention; we explicitly transpose the kernel output from N,H,V,K for comparison against the reference.

PR #2370 already provides both a pre‑transpose (N,H,V,K) and a non‑transpose (N,H,K,V) path, and the inline comments recommend the non‑transpose version. I’ve searched the Qwen3 documentation and the official Hugging Face Transformers implementation (which uses FLA under the hood) and haven’t found any explicit requirement for the N,H,V,K layout.

Is there a downstream consumer or hardware target that mandates this layout?

@yzh119 convinced me in Slack that K-last is better. vLLM already moved to the K-last layout, and that by itself gives a perf improvement :)
Also, FI's GDN decode uses the K-last layout.
@yzh119 @guangyunh-nv could you please comment on that?

@ameynaik-hub
Author

Generally speaking, if the state is K-last (K contiguous), you can vectorize the reads along the K dim, and the inner products with k and q can be performed in parallel across rows more easily.
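
A small PyTorch illustration of the point (reference-style code, not the kernel itself): it converts an FLA-style [B, HV, K, V] state to the K-last [B, HV, V, K] layout used here and computes pred = H·k as a reduction along the contiguous K dimension.

```python
import torch

B, HV, K, V = 2, 4, 128, 128
# FLA-style state: K before V, so V is the contiguous dim.
h_kv = torch.randn(B, HV, K, V, dtype=torch.float32)
# K-last state: transpose the last two dims so K becomes the fastest-varying dim.
h_klast = h_kv.transpose(-1, -2).contiguous()        # [B, HV, V, K]

k = torch.randn(B, HV, K, dtype=torch.float32)       # one key per (batch, value head)

# pred[b, h, v] = sum_k H[b, h, v, k] * k[b, h, k]
# With K-last, each row h_klast[b, h, v, :] is contiguous, so the inner product reads
# memory sequentially (vectorizable) and all V rows can be processed in parallel.
pred = torch.einsum("bhvk,bhk->bhv", h_klast, k)

# Same result from the K-first layout, for reference.
pred_ref = torch.einsum("bhkv,bhk->bhv", h_kv, k)
assert torch.allclose(pred, pred_ref, atol=1e-5)
```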

@guangyunh-nv
Collaborator

@vadiklyutiy The reason for K-last is as follows:

In prefill, you need to repeatedly update the state in the kernel mainloop, so you'd better keep the state in registers (on Hopper). This means you put the state S in the A operand, which also avoids the repeated transposition and updating of the state in the B operand that an SMEM-style kernel would require. This is very critical to kernel performance.

I believe Blackwell has the same constraint due to the TMEM access pattern.

…scing and bank conflicts

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 312-359: Remove the dead o_head parameter from the
process_first_token function signature and all call sites; update the function
signature for process_first_token to omit o_head and adjust every invocation
(notably in process_vchunk_unified_234) to stop passing that argument, ensuring
argument order and remaining parameters (h_sh_chunk_curr, h_chunk, kq_chunk,
k_sh, q_sh, v_sh, reduce_sh, g_exp, beta, v_offset, pred_slot, warp_idx,
lane_idx, k_base) remain unchanged; rebuild and run tests to confirm no other
call sites reference o_head.
🧹 Nitpick comments (2)
flashinfer/cute_dsl/gated_delta_rule.py (2)

696-1097: Seqlen=1 kernel: 4 chunks are manually unrolled with identical logic.

Chunks 0–3 (lines 807–1096) repeat the same pattern: wait for async load → decay H from SMEM → compute pred → cross-warp reduce → delta update → compute output → reduce → write-back. This is ~290 lines of near-identical code with only the chunk index and V-offset changing. Consider factoring the per-chunk logic into a helper (similar to process_vchunk_unified_234), passing the chunk-specific SMEM buffer and offset as parameters. The persistent-K optimization and pipelining schedule can still be preserved.


2009-2026: --generate-line-info left in production compile options.

This flag adds debug line info to the compiled kernel binary, increasing code size and potentially impacting instruction cache utilization. Consider making it conditional (e.g., controlled by an env var or debug flag) or removing it for production builds.
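
A hedged sketch of gating the flag behind an environment variable (the variable name is made up for illustration):

```python
import os

# Hypothetical gate: keep --generate-line-info out of production builds and only add it
# when profiling with ncu/nsight, e.g. FLASHINFER_CUTE_LINEINFO=1 python bench_gdn_decode.py
_DEBUG_LINEINFO = os.environ.get("FLASHINFER_CUTE_LINEINFO", "0") == "1"
_lineinfo_flag = "--generate-line-info" if _DEBUG_LINEINFO else ""
```

The resulting flag can then be folded into the options passed to cute.compile, assuming the compiler accepts multiple space-separated flags in that string.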

Comment on lines +312 to +359
def process_first_token(
    h_sh_chunk_curr,
    h_chunk,
    kq_chunk,
    k_sh,
    q_sh,
    v_sh,
    reduce_sh,
    o_head,
    g_exp,
    beta,
    v_offset,
    pred_slot,
    warp_idx,
    lane_idx,
    k_base,
):
    """
    Process the first token in a V-chunk (T=0).
    - Load K from SMEM
    - Decay H from SMEM and compute pred
    - Cross-warp reduce pred (uses pred_slot)
    - Update H with delta
    - Load Q and compute output
    Returns: out (partial output, not yet reduced)
    """
    # Load K for this token
    load_kq_chunk_from_smem(k_sh, kq_chunk, k_base)

    # Decay H from SMEM and compute pred = H * K
    pred = decay_h_from_smem_and_compute_pred(
        h_sh_chunk_curr, h_chunk, kq_chunk, g_exp, lane_idx, k_base
    )

    # Reduce pred across warps (slot 0 for first token)
    pred_final = cross_warp_reduce_single(
        reduce_sh, pred_slot, warp_idx, lane_idx, pred
    )

    # Compute delta and update H
    v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta
    update_h_with_delta(h_chunk, kq_chunk, v_delta)

    # Load Q and compute output
    load_kq_chunk_from_smem(q_sh, kq_chunk, k_base)
    out = compute_output(h_chunk, kq_chunk)

    return out
Contributor

⚠️ Potential issue | 🟡 Minor

Unused o_head parameter in process_first_token.

The o_head parameter (line 320) is never referenced in the function body. The first token's output is returned as out and stored by the next token's processing function. This dead parameter adds confusion and wastes a register.

Proposed fix
 `@cute.jit`
 def process_first_token(
     h_sh_chunk_curr,
     h_chunk,
     kq_chunk,
     k_sh,
     q_sh,
     v_sh,
     reduce_sh,
-    o_head,
     g_exp,
     beta,
     v_offset,
     pred_slot,
     warp_idx,
     lane_idx,
     k_base,
 ):

And update all call sites accordingly (in process_vchunk_unified_234).

🧰 Tools
🪛 Ruff (0.15.0)

[warning] 320-320: Unused function argument: o_head

(ARG001)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 1179-1182: The comment is incorrect: reduce_sh allocated via
smem.allocate_tensor(cutlass.Float32, cute.make_layout((8, 32, 4), stride=(128,
4, 1))) yields 4-way bank conflicts for reduce_sh[slot, lane_idx, warp_idx];
either update the comment to state it’s 4-way conflicted or change the layout to
a bank-conflict-free interleaving (e.g., replace the stride tuple with an
interleaved BCF stride such as using 33-element pitch for the lane dimension
like stride=(33*4, 1, 33) or equivalent) so accesses across lane_idx map to
unique banks; modify the smem.allocate_tensor call for reduce_sh accordingly and
keep the symbol names reduce_sh, smem.allocate_tensor, cute.make_layout and
cutlass.Float32 to locate the change.
🧹 Nitpick comments (2)
flashinfer/cute_dsl/gated_delta_rule.py (2)

696-1097: Seqlen=1 kernel: 4 nearly-identical chunk blocks (~300 lines of repeated code).

Chunks 0–3 (lines 807–1096) repeat the same decay → pred → reduce → delta → output → H-writeback pattern with only the SMEM buffer name and V offset changing. While inlining is understandable for pipelining (the async load/store interleaving differs per chunk), the core compute block could be extracted into a helper that takes the chunk's H SMEM buffer and V offset, similar to how process_vchunk_unified_234 factored the seqlen≥2 path. This would cut ~200 lines with no runtime cost since CuTe-DSL will inline @cute.jit helpers at compile time.


1998-2007: Low-BS dispatch threshold (B <= 4) is hard-coded without documented rationale.

The T=1 path selects the low-BS kernel (4 CTAs per batch×head) when B <= 4. This threshold likely comes from empirical benchmarking, but there's no comment or constant explaining the crossover point. If the benchmark characteristics change (e.g., different GPU, different HV), this threshold could become stale. Consider extracting 4 to a named constant (e.g., _LOWBS_THRESHOLD = 4) with a brief comment about the rationale.

Comment on lines +1179 to +1182
# Bank-conflict-free reduce_sh: [slot, lane_idx, warp_idx]
reduce_sh = smem.allocate_tensor(
    cutlass.Float32, cute.make_layout((8, 32, 4), stride=(128, 4, 1))
)
Contributor

⚠️ Potential issue | 🟡 Minor

reduce_sh layout has 4-way bank conflicts despite the "bank-conflict-free" comment.

With stride (128, 4, 1), reduce_sh[slot, lane_idx, warp_idx] maps to offset slot*128 + lane_idx*4 + warp_idx. Within a single warp reading a fixed warp_idx, the bank is (lane_idx*4 + warp_idx) % 32, which means lanes 0, 8, 16, 24 (and similar groups) map to the same bank — yielding 4-way conflicts on both writes and reads.

Compare with the seqlen=1 kernel's pred_sh/out_sh at lines 760–764, which uses stride (1, 32) — truly bank-conflict-free since lane_idx + warp_idx*32 gives each lane a unique bank.

Consider either updating the comment to reflect the actual conflict level, or switching to a BCF layout (e.g., (8, 32, 4) with stride (33*4, 1, 33) or similar interleaving).
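
The claim is easy to check numerically: with 4-byte words and 32 banks, bank = (word offset) mod 32. A quick script comparing the current stride with the suggested padded one:

```python
from collections import Counter

def lanes_per_bank(stride, slot=0, warp_idx=0):
    """Max number of lanes of one warp that map to the same shared-memory bank for
    reduce_sh[slot, lane, warp_idx], assuming 4-byte words and 32 banks."""
    s_slot, s_lane, s_warp = stride
    banks = [(slot * s_slot + lane * s_lane + warp_idx * s_warp) % 32 for lane in range(32)]
    return max(Counter(banks).values())

print(lanes_per_bank((128, 4, 1)))   # 4  -> 4-way conflict (lane stride 4 wraps every 8 lanes)
print(lanes_per_bank((132, 1, 33)))  # 1  -> conflict-free (33-word pitch per warp column)
```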

…mproved_cutedsl

- gdn_decode.py: optional backend for pretranspose when bf16 state, T<=4, K=V=128; dispatch to cute_dsl gated_delta_rule.
- bench_gdn_decode.py: rename improved_cutedsl to gdn_decode_klast_bf16_state (--version, wrapper, result keys).
- test_decode_delta_rule.py: same rename; add test_pretranspose_api_uses_gdn_decode_klast_bf16_state for API dispatch.

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
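
A rough sketch of the dispatch condition described above (the helper name and exact guard set are illustrative; the real check lives in gdn_decode.py):

```python
import torch

def _use_klast_bf16_cutedsl(q: torch.Tensor, state: torch.Tensor, T: int) -> bool:
    """Illustrative guard: route to the CuTe-DSL gated_delta_rule backend only when the
    state is bf16 in the K-last [B, HV, V, K] layout, T <= 4, and K = V = 128."""
    K = q.shape[-1]
    V = state.shape[-2]
    return (
        state.dtype == torch.bfloat16
        and T <= 4
        and K == 128
        and V == 128
        and state.stride(-1) == 1  # K is the contiguous (fastest) dim
    )
```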
@ameynaik-hub
Author

ameynaik-hub commented Feb 11, 2026

CuTe-DSL bf16 H-state perf on B200

| batch | T | time (us) |
|------:|--:|----------:|
| 1 | 1 | 3.62 |
| 1 | 2 | 5.86 |
| 1 | 3 | 6.94 |
| 1 | 4 | 7.79 |
| 4 | 1 | 4.9 |
| 4 | 2 | 6.62 |
| 4 | 3 | 7.65 |
| 4 | 4 | 8.64 |
| 8 | 1 | 7.04 |
| 8 | 2 | 8.67 |
| 8 | 3 | 9.95 |
| 8 | 4 | 11.39 |
| 16 | 1 | 9.74 |
| 16 | 2 | 12.9 |
| 16 | 3 | 15.15 |
| 16 | 4 | 17.54 |
| 32 | 1 | 16.06 |
| 32 | 2 | 21.57 |
| 32 | 3 | 25.79 |
| 32 | 4 | 28.61 |
| 64 | 1 | 27.01 |
| 64 | 2 | 34.88 |
| 64 | 3 | 40.67 |
| 64 | 4 | 49.76 |
| 128 | 1 | 48.9 |
| 128 | 2 | 60.7 |
| 128 | 3 | 72.26 |
| 128 | 4 | 89.04 |
| 256 | 1 | 91.66 |
| 256 | 2 | 112.54 |
| 256 | 3 | 134.88 |
| 256 | 4 | 166.32 |
| 512 | 1 | 177.06 |
| 512 | 2 | 214.53 |
| 512 | 3 | 258.75 |
| 512 | 4 | 320.8 |

@kahyunnam
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !309 has been created, and the CI pipeline #43822304 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43822304: 11/20 passed

- Bank-conflict-free cross-warp reductions
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute for numerical stability
- GQA (grouped-query attention) support with configurable H (query) and HV (value) heads
Collaborator

@guangyunh-nv guangyunh-nv Feb 12, 2026

GQA: HQ = <IntegerRatio> * HV
GVA: <IntegerRatio> * HQ = HV. I confirmed this naming in a group chat with the GDN paper author :)

Author

@guangyunh-nv sorry do you mean it is called GVA and not GQA?

@aditya-narayan5

Thanks to everyone for clarifying, this helped :) @vadiklyutiy @guangyunh-nv @ameynaik-hub

@guangyunh-nv
Collaborator

Out of curiosity, why does the perf degrade when the batch goes from 1 to 4?

| batch | T | time (us) |
|------:|--:|----------:|
| 1 | 1 | 3.62 |
| 4 | 1 | 4.9 |

I'd assume you have not reached the parallelism or bandwidth limit, in which case the time should remain roughly constant.

@ameynaik-hub
Author

Out of curiosity, why does the perf degrade when the batch goes from 1 to 4?

@guangyunh-nv for BS <= 4 I use 4 CTAs per head, so for BS=1 it has 32 × 4 = 128 CTAs and for BS=4 it has 32 × 4 × 4 = 512 CTAs.
In this regime it is latency-bound; MIO throttling appears to be the reason based on NCU reports.
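
For reference, the grid-size arithmetic behind those numbers (assuming HV = 32 value heads, as in the examples elsewhere in this PR):

```python
# Low-BS path: 4 CTAs per (batch, value-head) pair.
HV = 32            # value heads (assumption; matches the docstring example)
CTAS_PER_HEAD = 4
for batch in (1, 4):
    print(batch, batch * HV * CTAS_PER_HEAD)  # 1 -> 128 CTAs, 4 -> 512 CTAs
```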

@@ -137,6 +137,7 @@ def blockwise_linear_attention(
| torch.Tensor = 1.0, # float or tensor with num_elems == num_qo_heads
decay_exponent_offset=0,
kv_dtype: torch.dtype = torch.float32,
Collaborator

I think you should remove kv_dtype; it is named kv_dtype because some earlier papers call the state itself "KV". It is not the dtype for K and V.

Author

done, can you please confirm if the change is okay? Thanks!
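
For context, the consolidated signature then looks roughly like this (sketch only; the other parameters of the reference function are elided):

```python
import torch

# Sketch of the rename: a single state_dtype knob for the recurrent state, replacing the
# confusingly named kv_dtype. The remaining parameters of blockwise_linear_attention are
# elided here and unchanged in the actual code.
def blockwise_linear_attention(
    # ... q/k/v, gating, and scale arguments as before ...
    decay_exponent_offset: int = 0,
    state_dtype: torch.dtype = torch.float32,  # dtype of the recurrent state, not of K/V
):
    ...
```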

- Consolidate to single state_dtype parameter across all reference functions
  - Remove duplicate kv_dtype parameter from blockwise_linear_attention(),
    delta_rule(), and blockwise_delta_rule()
  - Update test_prefill_delta_rule.py to use state_dtype consistently

- Remove benchmark_gated_delta_rule.py from git tracking (keep locally)
  - Add to .gitignore for local development use only

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Collaborator

@yzh119 yzh119 left a comment

Overall LGTM, some minor nits.

q: torch.Tensor, # [B, T, H_Q, K] where T=1,2,3,4
k: torch.Tensor, # [B, T, H_K, K]
v: torch.Tensor, # [B, T, HV, V]
state: torch.Tensor, # [B, HV, V, K] - K-fast layout (pretranspose)
Collaborator

Suggested change
state: torch.Tensor, # [B, HV, V, K] - K-fast layout (pretranspose)
state: torch.Tensor, # [B, HV, V, K] - K-last layout (pretranspose)

Collaborator

Can we move it to flashinfer/gdn_decode.py (or make it a module)? We don't want to keep the directory flashinfer/cute_dsl in the future.
