Ameyn/gdn decode cutedsl kernel #2498
ameynaik-hub wants to merge 9 commits into flashinfer-ai:main
Conversation
Implements a high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.

Key features:
- H state layout: K-last [B, HV, V, K], where K is the contiguous (fastest) dimension
- Unified kernel architecture: T=2/3/4 share a single compile-time specialized kernel via Constexpr dispatch; T=1 uses a separate kernel with a persistent-K optimization
- L2-normalized Q/K with configurable scale
- Gated exponential decay via softplus
- Delta rule updates: v_delta = beta * (v - pred)
- Bank-conflict-free cross-warp reductions
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute for numerical stability
- GQA (grouped-query attention) support

Also includes:
- benchmark_gated_delta_rule.py: simple benchmark script for measuring kernel performance
- Updated __init__.py exports

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
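For readers new to the operation, a minimal single-token PyTorch reference of the recurrence described above (per batch element and value head); the function name, the eps default, and the placement of scale on q are illustrative assumptions, not this kernel's API.

```python
import torch
import torch.nn.functional as F

def gated_delta_rule_reference(q, k, v, H, g_exp, beta, scale, eps=1e-6):
    """Single-token sketch of the recurrence (one value head of one batch element).

    q, k:  [K]    query/key for the matching GQA group
    v:     [V]    value head
    H:     [V, K] recurrent state in the K-last layout (K contiguous)
    g_exp, beta, scale: scalars; g_exp/beta come from the softplus/sigmoid gates
    """
    # L2-normalize Q/K; the scale is applied to q here (assumption for illustration)
    q = F.normalize(q.float(), dim=-1, eps=eps) * scale
    k = F.normalize(k.float(), dim=-1, eps=eps)

    H = g_exp * H.float()                # gated exponential decay of the state
    pred = H @ k                         # prediction of v from the decayed state, [V]
    v_delta = beta * (v.float() - pred)  # delta rule update: v_delta = beta * (v - pred)
    H = H + torch.outer(v_delta, k)      # rank-1 state update, [V, K]
    o = H @ q                            # output for this token, [V]
    return o, H
```

This mirrors, per 32-row V chunk, the decay / pred / delta-update / output sequence the CUDA kernel performs in registers and shared memory.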
Summary of Changes

Hello @ameynaik-hub, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a highly optimized CUDA kernel for the Gated Delta Rule linear attention mechanism, specifically tailored for decode-phase inference. The implementation, built with NVIDIA CuTe-DSL, provides high performance for fixed sequence lengths of 1, 2, 3, and 4. It incorporates optimizations such as a K-last H state layout, L2-normalized Q/K, gated exponential decay, delta rule updates, and aggressive asynchronous memory pipelining, all while maintaining numerical stability with BF16 tensors and FP32 compute. A new benchmark script is also included to evaluate the kernel's performance.
📝 Walkthrough

Adds a CuTe-DSL Gated Delta Rule CUDA kernel (seq_len 1–4) with a Python launcher/class, a standalone benchmark, benchmark integration, reference dtype/state handling updates, and tests exercising the new kernel and BF16 state paths.
Sequence Diagram(s)

sequenceDiagram
participant PythonAPI as Python API
participant KernelCache as Kernel Cache
participant CuTe as CuTe Compiler
participant GPU as GPU Kernel
participant Memory as GPU Memory
PythonAPI->>PythonAPI: validate inputs (q,k,v,A_log,a,dt_bias,seq_len)
PythonAPI->>KernelCache: lookup compiled kernel (T, dtypes)
alt cached
KernelCache-->>PythonAPI: return cached kernel
else not cached
PythonAPI->>CuTe: compile kernel for (T, dtypes)
CuTe-->>KernelCache: store compiled kernel
KernelCache-->>PythonAPI: return compiled kernel
end
PythonAPI->>GPU: launch kernel on CUDA stream with args
GPU->>Memory: async load Q,K,V,gates,state blocks
GPU->>GPU: l2-normalize, gated decay, delta-rule updates, reductions
GPU->>Memory: write output [B,T,HV,V] and updated state
GPU-->>PythonAPI: signal completion / return tensor
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Code Review
This pull request introduces a high-performance Gated Delta Rule linear attention kernel using CuTe-DSL, supporting sequence lengths from 1 to 4. The implementation includes separate, optimized kernels for T=1 and a unified kernel for T=2,3,4, along with a benchmark script. My review focuses on the new kernel implementation in gated_delta_rule.py. I've identified a few areas for improvement: the seqlen=1 kernel contains significant code duplication that could be refactored for better maintainability. The kernel caching strategy could lead to performance issues with dynamic batch sizes due to unnecessary recompilations. Finally, the exported GatedDeltaRuleKernel class appears to be incomplete or unused. Addressing these points will improve the robustness and performance of the new kernel.
| stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) | ||
|
|
||
| # Check cache | ||
| cache_key = (T, B) |
The kernel caching mechanism uses the batch size B as part of the cache key. Because the kernel is compiled with concrete tensor shapes (derived from from_dlpack), this will trigger a recompilation for every new batch size encountered at runtime. For applications with dynamic batch sizes, which are common in inference scenarios, this can lead to significant performance degradation due to repeated JIT compilation.
To address this, I recommend using symbolic dimensions for the batch size during compilation to generate a more generic kernel. This can be achieved by using cute.runtime.make_fake_compact_tensor with a cute.sym_int() for the batch dimension, instead of using from_dlpack directly on the input tensors within the compilation path. You can find an example of this pattern in flashinfer/cute_dsl/add_rmsnorm_fp4quant.py. This change would allow you to remove the batch size from the cache key, preventing unnecessary recompilations and improving runtime performance.
| def gated_delta_rule_decode_kernel_seqlen1( | ||
| gQ: cute.Tensor, | ||
| gK: cute.Tensor, | ||
| gV: cute.Tensor, | ||
| ga: cute.Tensor, | ||
| gb: cute.Tensor, | ||
| gA_log: cute.Tensor, | ||
| gdt_bias: cute.Tensor, | ||
| gH: cute.Tensor, | ||
| gO: cute.Tensor, | ||
| scale: cutlass.Float32, | ||
| softplus_beta: cutlass.Float32, | ||
| softplus_threshold: cutlass.Float32, | ||
| eps: cutlass.Float32, | ||
| ): | ||
| """ | ||
| Seqlen=1 kernel with persistent K optimization. | ||
| OPTIMIZATIONS: | ||
| 1. PERSISTENT K IN REGISTERS ONLY: K[k_base:k_base+32] kept for entire kernel | ||
| Q is reloaded per chunk (lower register pressure than V3) | ||
| 2. AGGRESSIVE PIPELINING: Load chunks 2 ahead, store during next compute | ||
| 3. [4,32] CROSS-WARP REDUCTION: Correct lane-preserving reduction | ||
| """ | ||
| tidx, _, _ = cute.arch.thread_idx() | ||
| bidx, _, _ = cute.arch.block_idx() | ||
|
|
||
| HV = cutlass.Int32(gV.shape[2]) | ||
| H = cutlass.Int32(gQ.shape[2]) | ||
|
|
||
| batch_idx = bidx // HV | ||
| value_head_idx = bidx % HV | ||
| query_head_idx = value_head_idx // (HV // H) | ||
|
|
||
| smem = utils.SmemAllocator() | ||
|
|
||
| # Compute gates using shared helper | ||
| alpha = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) | ||
| beta_raw = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) | ||
| A_log_val = gA_log[value_head_idx] | ||
| dt_bias_val = gdt_bias[value_head_idx] | ||
| g_exp, beta = compute_single_gate( | ||
| alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold | ||
| ) | ||
|
|
||
| # Allocate SMEM | ||
| h_sh_chunk0 = smem.allocate_tensor( | ||
| cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) | ||
| ) | ||
| h_sh_chunk1 = smem.allocate_tensor( | ||
| cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) | ||
| ) | ||
| h_sh_chunk2 = smem.allocate_tensor( | ||
| cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) | ||
| ) | ||
| h_sh_chunk3 = smem.allocate_tensor( | ||
| cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) | ||
| ) | ||
|
|
||
| q_sh = smem.allocate_tensor(cutlass.Float32, 128) | ||
| k_sh = smem.allocate_tensor(cutlass.Float32, 128) | ||
|
|
||
| pred_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) | ||
| out_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) | ||
|
|
||
| h_global = gH[(batch_idx, value_head_idx, None, None)] | ||
|
|
||
| # Launch first 2 async loads | ||
| load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0) | ||
| nvvm.cp_async_commit_group() | ||
| load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32) | ||
| nvvm.cp_async_commit_group() | ||
|
|
||
| # L2 normalization | ||
| q_head = gQ[(batch_idx, 0, query_head_idx, None)] | ||
| k_head = gK[(batch_idx, 0, query_head_idx, None)] | ||
|
|
||
| warp_idx = tidx // 32 | ||
| lane_idx = tidx % 32 | ||
|
|
||
| # Use shared helper for Q/K normalization (only warp 0 does the work) | ||
| if warp_idx == 0: | ||
| normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps) | ||
|
|
||
| cute.arch.sync_threads() | ||
|
|
||
| # Load V | ||
| v_head = gV[(batch_idx, 0, value_head_idx, None)] | ||
| v_sh = smem.allocate_tensor(cutlass.Float32, 128) | ||
| v_sh[tidx] = v_head[tidx].to(cutlass.Float32) | ||
|
|
||
| # Registers: h_chunk + k_chunk (persistent) + qk_temp (reused for Q) | ||
| h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) | ||
| k_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) # PERSISTENT K! | ||
| qk_temp = cute.make_rmem_tensor((32,), cutlass.Float32) | ||
|
|
||
| k_base = warp_idx * 32 | ||
|
|
||
| # Load K ONCE - keep for entire kernel | ||
| for i in cutlass.range_constexpr(32): | ||
| k_chunk[i] = k_sh[k_base + i] | ||
|
|
||
| h_out = gH[(batch_idx, value_head_idx, None, None)] | ||
| o_head = gO[(batch_idx, 0, value_head_idx, None)] | ||
|
|
||
| # ======================================================================== | ||
| # CHUNK 0 | ||
| # ======================================================================== | ||
| nvvm.cp_async_wait_group(1) | ||
| cute.arch.sync_threads() | ||
|
|
||
| pred = cutlass.Float32(0.0) | ||
| pred2 = cutlass.Float32(0.0) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( | ||
| src_a=( | ||
| h_sh_chunk0[lane_idx, k_base + i].to(cutlass.Float32), | ||
| h_sh_chunk0[lane_idx, k_base + i + 1].to(cutlass.Float32), | ||
| ), | ||
| src_b=(g_exp, g_exp), | ||
| src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), | ||
| ) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| pred, pred2 = cute.arch.fma_packed_f32x2( | ||
| src_a=(h_chunk[i], h_chunk[i + 1]), | ||
| src_b=(k_chunk[i], k_chunk[i + 1]), | ||
| src_c=(pred, pred2), | ||
| ) | ||
| pred = pred + pred2 | ||
|
|
||
| pred_sh[warp_idx, lane_idx] = pred | ||
| cute.arch.sync_threads() | ||
| pred_final = ( | ||
| pred_sh[0, lane_idx] | ||
| + pred_sh[1, lane_idx] | ||
| + pred_sh[2, lane_idx] | ||
| + pred_sh[3, lane_idx] | ||
| ) | ||
|
|
||
| v_val = (v_sh[lane_idx] - pred_final) * beta | ||
|
|
||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( | ||
| src_a=(k_chunk[i], k_chunk[i + 1]), | ||
| src_b=(v_val, v_val), | ||
| src_c=(h_chunk[i], h_chunk[i + 1]), | ||
| ) | ||
|
|
||
| # Load Q for output computation | ||
| for i in cutlass.range_constexpr(32): | ||
| qk_temp[i] = q_sh[k_base + i] | ||
|
|
||
| out = cutlass.Float32(0.0) | ||
| out2 = cutlass.Float32(0.0) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| out, out2 = cute.arch.fma_packed_f32x2( | ||
| src_a=(h_chunk[i], h_chunk[i + 1]), | ||
| src_b=(qk_temp[i], qk_temp[i + 1]), | ||
| src_c=(out, out2), | ||
| ) | ||
| out = out + out2 | ||
|
|
||
| out_sh[warp_idx, lane_idx] = out | ||
| cute.arch.sync_threads() | ||
| out_final = ( | ||
| out_sh[0, lane_idx] | ||
| + out_sh[1, lane_idx] | ||
| + out_sh[2, lane_idx] | ||
| + out_sh[3, lane_idx] | ||
| ) | ||
|
|
||
| write_h_chunk_to_smem(h_chunk, h_sh_chunk0, lane_idx, k_base) | ||
| if warp_idx == 0: | ||
| o_head[lane_idx] = out_final.to(cutlass.BFloat16) | ||
|
|
||
| # ======================================================================== | ||
| # CHUNK 1 | ||
| # ======================================================================== | ||
| nvvm.cp_async_wait_group(0) | ||
| cute.arch.sync_threads() | ||
|
|
||
| load_h_chunk_async(h_sh_chunk2, h_global, tidx, 64) | ||
| nvvm.cp_async_commit_group() | ||
| load_h_chunk_async(h_sh_chunk3, h_global, tidx, 96) | ||
| nvvm.cp_async_commit_group() | ||
|
|
||
| store_h_smem_to_gmem(h_sh_chunk0, h_out, tidx, 0) | ||
|
|
||
| pred = cutlass.Float32(0.0) | ||
| pred2 = cutlass.Float32(0.0) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( | ||
| src_a=( | ||
| h_sh_chunk1[lane_idx, k_base + i].to(cutlass.Float32), | ||
| h_sh_chunk1[lane_idx, k_base + i + 1].to(cutlass.Float32), | ||
| ), | ||
| src_b=(g_exp, g_exp), | ||
| src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), | ||
| ) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| pred, pred2 = cute.arch.fma_packed_f32x2( | ||
| src_a=(h_chunk[i], h_chunk[i + 1]), | ||
| src_b=(k_chunk[i], k_chunk[i + 1]), | ||
| src_c=(pred, pred2), | ||
| ) | ||
| pred = pred + pred2 | ||
|
|
||
| pred_sh[warp_idx, lane_idx] = pred | ||
| cute.arch.sync_threads() | ||
| pred_final = ( | ||
| pred_sh[0, lane_idx] | ||
| + pred_sh[1, lane_idx] | ||
| + pred_sh[2, lane_idx] | ||
| + pred_sh[3, lane_idx] | ||
| ) | ||
|
|
||
| v_val = (v_sh[32 + lane_idx] - pred_final) * beta | ||
|
|
||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( | ||
| src_a=(k_chunk[i], k_chunk[i + 1]), | ||
| src_b=(v_val, v_val), | ||
| src_c=(h_chunk[i], h_chunk[i + 1]), | ||
| ) | ||
|
|
||
| for i in cutlass.range_constexpr(32): | ||
| qk_temp[i] = q_sh[k_base + i] | ||
|
|
||
| out = cutlass.Float32(0.0) | ||
| out2 = cutlass.Float32(0.0) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| out, out2 = cute.arch.fma_packed_f32x2( | ||
| src_a=(h_chunk[i], h_chunk[i + 1]), | ||
| src_b=(qk_temp[i], qk_temp[i + 1]), | ||
| src_c=(out, out2), | ||
| ) | ||
| out = out + out2 | ||
|
|
||
| out_sh[warp_idx, lane_idx] = out | ||
| cute.arch.sync_threads() | ||
| out_final = ( | ||
| out_sh[0, lane_idx] | ||
| + out_sh[1, lane_idx] | ||
| + out_sh[2, lane_idx] | ||
| + out_sh[3, lane_idx] | ||
| ) | ||
|
|
||
| write_h_chunk_to_smem(h_chunk, h_sh_chunk1, lane_idx, k_base) | ||
| if warp_idx == 0: | ||
| o_head[32 + lane_idx] = out_final.to(cutlass.BFloat16) | ||
|
|
||
| # ======================================================================== | ||
| # CHUNK 2 | ||
| # ======================================================================== | ||
| nvvm.cp_async_wait_group(1) | ||
| cute.arch.sync_threads() | ||
|
|
||
| store_h_smem_to_gmem(h_sh_chunk1, h_out, tidx, 32) | ||
|
|
||
| pred = cutlass.Float32(0.0) | ||
| pred2 = cutlass.Float32(0.0) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( | ||
| src_a=( | ||
| h_sh_chunk2[lane_idx, k_base + i].to(cutlass.Float32), | ||
| h_sh_chunk2[lane_idx, k_base + i + 1].to(cutlass.Float32), | ||
| ), | ||
| src_b=(g_exp, g_exp), | ||
| src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), | ||
| ) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| pred, pred2 = cute.arch.fma_packed_f32x2( | ||
| src_a=(h_chunk[i], h_chunk[i + 1]), | ||
| src_b=(k_chunk[i], k_chunk[i + 1]), | ||
| src_c=(pred, pred2), | ||
| ) | ||
| pred = pred + pred2 | ||
|
|
||
| pred_sh[warp_idx, lane_idx] = pred | ||
| cute.arch.sync_threads() | ||
| pred_final = ( | ||
| pred_sh[0, lane_idx] | ||
| + pred_sh[1, lane_idx] | ||
| + pred_sh[2, lane_idx] | ||
| + pred_sh[3, lane_idx] | ||
| ) | ||
|
|
||
| v_val = (v_sh[64 + lane_idx] - pred_final) * beta | ||
|
|
||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( | ||
| src_a=(k_chunk[i], k_chunk[i + 1]), | ||
| src_b=(v_val, v_val), | ||
| src_c=(h_chunk[i], h_chunk[i + 1]), | ||
| ) | ||
|
|
||
| for i in cutlass.range_constexpr(32): | ||
| qk_temp[i] = q_sh[k_base + i] | ||
|
|
||
| out = cutlass.Float32(0.0) | ||
| out2 = cutlass.Float32(0.0) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| out, out2 = cute.arch.fma_packed_f32x2( | ||
| src_a=(h_chunk[i], h_chunk[i + 1]), | ||
| src_b=(qk_temp[i], qk_temp[i + 1]), | ||
| src_c=(out, out2), | ||
| ) | ||
| out = out + out2 | ||
|
|
||
| out_sh[warp_idx, lane_idx] = out | ||
| cute.arch.sync_threads() | ||
| out_final = ( | ||
| out_sh[0, lane_idx] | ||
| + out_sh[1, lane_idx] | ||
| + out_sh[2, lane_idx] | ||
| + out_sh[3, lane_idx] | ||
| ) | ||
|
|
||
| write_h_chunk_to_smem(h_chunk, h_sh_chunk2, lane_idx, k_base) | ||
| if warp_idx == 0: | ||
| o_head[64 + lane_idx] = out_final.to(cutlass.BFloat16) | ||
|
|
||
| # ======================================================================== | ||
| # CHUNK 3 | ||
| # ======================================================================== | ||
| nvvm.cp_async_wait_group(0) | ||
| cute.arch.sync_threads() | ||
|
|
||
| store_h_smem_to_gmem(h_sh_chunk2, h_out, tidx, 64) | ||
|
|
||
| pred = cutlass.Float32(0.0) | ||
| pred2 = cutlass.Float32(0.0) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( | ||
| src_a=( | ||
| h_sh_chunk3[lane_idx, k_base + i].to(cutlass.Float32), | ||
| h_sh_chunk3[lane_idx, k_base + i + 1].to(cutlass.Float32), | ||
| ), | ||
| src_b=(g_exp, g_exp), | ||
| src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), | ||
| ) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| pred, pred2 = cute.arch.fma_packed_f32x2( | ||
| src_a=(h_chunk[i], h_chunk[i + 1]), | ||
| src_b=(k_chunk[i], k_chunk[i + 1]), | ||
| src_c=(pred, pred2), | ||
| ) | ||
| pred = pred + pred2 | ||
|
|
||
| pred_sh[warp_idx, lane_idx] = pred | ||
| cute.arch.sync_threads() | ||
| pred_final = ( | ||
| pred_sh[0, lane_idx] | ||
| + pred_sh[1, lane_idx] | ||
| + pred_sh[2, lane_idx] | ||
| + pred_sh[3, lane_idx] | ||
| ) | ||
|
|
||
| v_val = (v_sh[96 + lane_idx] - pred_final) * beta | ||
|
|
||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( | ||
| src_a=(k_chunk[i], k_chunk[i + 1]), | ||
| src_b=(v_val, v_val), | ||
| src_c=(h_chunk[i], h_chunk[i + 1]), | ||
| ) | ||
|
|
||
| for i in cutlass.range_constexpr(32): | ||
| qk_temp[i] = q_sh[k_base + i] | ||
|
|
||
| out = cutlass.Float32(0.0) | ||
| out2 = cutlass.Float32(0.0) | ||
| for i in cutlass.range_constexpr(0, 32, 2): | ||
| out, out2 = cute.arch.fma_packed_f32x2( | ||
| src_a=(h_chunk[i], h_chunk[i + 1]), | ||
| src_b=(qk_temp[i], qk_temp[i + 1]), | ||
| src_c=(out, out2), | ||
| ) | ||
| out = out + out2 | ||
|
|
||
| out_sh[warp_idx, lane_idx] = out | ||
| cute.arch.sync_threads() | ||
| out_final = ( | ||
| out_sh[0, lane_idx] | ||
| + out_sh[1, lane_idx] | ||
| + out_sh[2, lane_idx] | ||
| + out_sh[3, lane_idx] | ||
| ) | ||
|
|
||
| write_h_chunk_to_smem(h_chunk, h_sh_chunk3, lane_idx, k_base) | ||
| if warp_idx == 0: | ||
| o_head[96 + lane_idx] = out_final.to(cutlass.BFloat16) | ||
|
|
||
| cute.arch.sync_threads() | ||
| store_h_smem_to_gmem(h_sh_chunk3, h_out, tidx, 96) | ||
|
|
The logic for processing each of the four V-dimension chunks is nearly identical, leading to significant code duplication within this kernel. This makes the code difficult to read and maintain, as any change in the chunk processing logic needs to be manually replicated four times.
Consider refactoring the duplicated computation into a @cute.jit helper function. This function could encapsulate the logic for decaying the state, computing predictions, updating the state, and calculating the output for a single chunk. The main kernel would then manage the pipelined loading and storing of data while calling this helper function for each chunk. This would greatly improve code clarity and maintainability.
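One possible shape for such a helper, assembled from the chunk bodies quoted above; this is only a sketch (the caller would still own the cp.async loads, the H write-back via write_h_chunk_to_smem / store_h_smem_to_gmem, and the pipelining schedule), not code from this PR.

```python
import cutlass
import cutlass.cute as cute


@cute.jit
def process_v_chunk(
    h_sh_chunk,  # SMEM tile of this 32-row V chunk, [32, 128] BF16
    h_chunk,     # per-thread register tile, 32 x Float32
    k_chunk,     # persistent K registers, 32 x Float32
    q_sh,        # normalized Q in SMEM, 128 x Float32
    v_sh,        # V in SMEM, 128 x Float32
    pred_sh,     # [4, 32] cross-warp reduction buffer
    out_sh,      # [4, 32] cross-warp reduction buffer
    o_head,      # output slice in GMEM
    g_exp,
    beta,
    v_offset,    # 0 / 32 / 64 / 96: which V rows this chunk covers
    warp_idx,
    lane_idx,
    k_base,
):
    # Decay the state and accumulate pred = sum_k H[v, k] * K[k]
    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred_sh[warp_idx, lane_idx] = pred + pred2
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[0, lane_idx] + pred_sh[1, lane_idx]
        + pred_sh[2, lane_idx] + pred_sh[3, lane_idx]
    )

    # Delta-rule state update and output: H += K * v_delta, out = sum_k H[v, k] * Q[k]
    v_val = (v_sh[v_offset + lane_idx] - pred_final) * beta
    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(q_sh[k_base + i], q_sh[k_base + i + 1]),
            src_c=(out, out2),
        )
    out_sh[warp_idx, lane_idx] = out + out2
    cute.arch.sync_threads()
    out_final = (
        out_sh[0, lane_idx] + out_sh[1, lane_idx]
        + out_sh[2, lane_idx] + out_sh[3, lane_idx]
    )
    if warp_idx == 0:
        o_head[v_offset + lane_idx] = out_final.to(cutlass.BFloat16)
```

The decay/pred and update/output loops are fused here; they stay element-wise equivalent to the unrolled chunk bodies because each index i is processed independently.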
| class GatedDeltaRuleKernel: | ||
| """ | ||
| Gated Delta Rule Kernel for linear attention decode. | ||
|
|
||
| This kernel implements the Gated Delta Rule mechanism supporting sequence | ||
| lengths T=1, T=2, T=3, T=4 with optimized CUDA implementations. | ||
|
|
||
| Key features: | ||
| - T=1: Persistent K in registers with aggressive pipelining | ||
| - T=2/3/4: Unified kernel with compile-time Constexpr specialization | ||
| - L2-normalized Q/K with configurable scale | ||
| - Gated exponential decay via softplus | ||
| - Bank-conflict-free cross-warp reductions | ||
| - Async H memory loading | ||
|
|
||
| Args: | ||
| seq_len: Sequence length (1, 2, 3, or 4) | ||
| """ | ||
|
|
||
| def __init__(self, seq_len: int): | ||
| assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}" | ||
| self.seq_len = seq_len | ||
| self._compiled_kernel = None | ||
|
|
||
| def _get_launch_fn(self): | ||
| if self.seq_len == 1: | ||
| return gated_delta_rule_launch_seqlen1 | ||
| elif self.seq_len == 2: | ||
| return gated_delta_rule_launch_seqlen2 | ||
| elif self.seq_len == 3: | ||
| return gated_delta_rule_launch_seqlen3 | ||
| else: | ||
| return gated_delta_rule_launch_seqlen4 |
The GatedDeltaRuleKernel class is defined and exported as part of the public API, but it appears to be incomplete and is not used by the main gated_delta_rule function. The _compiled_kernel member is initialized to None and is never assigned a compiled kernel, and there is no method to execute it. The functional entry point gated_delta_rule implements its own caching and launch logic.
If this class is intended for future use, it should be fully implemented. If it is obsolete or a work-in-progress, it should either be completed or removed to avoid confusing users of the library.
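For illustration, a minimal way the class could be finished as a thin wrapper over the functional entry point; this simply delegates to the existing gated_delta_rule (which already owns compilation and caching), so it is an assumption about intent rather than the author's design.

```python
class GatedDeltaRuleKernel:
    """Convenience wrapper around the functional gated_delta_rule entry point."""

    def __init__(self, seq_len: int):
        assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}"
        self.seq_len = seq_len

    def __call__(self, A_log, a, dt_bias, *, q, k, v, b, initial_state_source,
                 scale=None, softplus_beta=1.0, softplus_threshold=20.0):
        # Sanity-check the sequence length this instance was constructed for,
        # then delegate; gated_delta_rule handles compilation and kernel caching.
        assert q.shape[1] == self.seq_len, (
            f"Kernel was constructed for T={self.seq_len}, got T={q.shape[1]}"
        )
        return gated_delta_rule(
            A_log=A_log, a=a, dt_bias=dt_bias,
            q=q, k=k, v=v, b=b,
            initial_state_source=initial_state_source,
            scale=scale,
            softplus_beta=softplus_beta,
            softplus_threshold=softplus_threshold,
        )
```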
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 1773-1774: The cache_key currently set to (T, B) is too narrow and
can cause incorrect kernel reuse; update the cache key creation (variable
cache_key) to include the relevant tensor dimensions such as H, HV, K, and V
(the head count and per-head dims) so it uniquely identifies tensor shapes used
by the kernel—derive these from the involved tensors (the q/k/v/proj shapes or
whatever variables represent heads and head sizes) and include them alongside T
and B when building cache_key to prevent shape-mismatch cache hits.
- Around line 1666-1680: GatedDeltaRuleKernel currently only stores seq_len and
_compiled_kernel and provides _get_launch_fn but is never used by
gated_delta_rule or _compiled_kernels; either finish the class by adding a
public execution API (e.g., an __call__ or execute method that compiles/looks up
the kernel into _compiled_kernel, uses _get_launch_fn to obtain the launch
function and runs it with the same signature as the module-level implementation)
or remove the class from the public surface and migrate any kernel-caching logic
into the existing module-level _compiled_kernels paths; reference the class name
GatedDeltaRuleKernel, its attribute _compiled_kernel, method _get_launch_fn, and
the public gated_delta_rule/_compiled_kernels cache when making the change.
- Around line 1746-1749: The code accesses q, k, v, b and initial_state_source
without null checks even though their annotations allow None; add explicit
validation at the start of the function (e.g., in the function containing the
lines that reference q.shape and v.shape) to raise a clear ValueError if any of
q, k, v, b, or initial_state_source is None (or alternatively update the
function signature to remove Optional typing for these parameters), and ensure
any downstream use of from_dlpack(initial_state_source, ...) only occurs after
confirming initial_state_source is not None; reference the variables q, k, v, b
and initial_state_source and the shapes accessed (q.shape, v.shape[2],
v.shape[3]) so the checks are placed before those accesses.
- Around line 1153-1165: Unconditional shared-memory allocations q_sh2, k_sh2,
q_sh3, k_sh3 and the extra v_sh* buffers are always created even when NUM_TOKENS
is smaller; wrap those smem.allocate_tensor(...) calls in compile-time guards so
unused buffers are eliminated. Specifically, change the allocations for
q_sh2/k_sh2 to be inside an if constexpr (NUM_TOKENS >= 3) block and the
allocations for q_sh3/k_sh3 to be inside an if constexpr (NUM_TOKENS == 4) block
(and similarly guard v_sh2/v_sh3 as appropriate), keeping the same
smem.allocate_tensor(cutlass.Float32, 128) calls and names so later code (gate
computations) still references the same identifiers when compiled in. Ensure the
guards use the compile-time NUM_TOKENS parameter so the compiler can drop unused
allocations.
🧹 Nitpick comments (5)
flashinfer/cute_dsl/gated_delta_rule.py (3)
118-136: Potential numerical stability concern in compute_single_gate.
The softplus implementation handles large positive values but not large negative values of beta_x. When beta_x is very negative, cute.math.exp(beta_x) approaches zero, which is fine numerically. However, the sigmoid computation at line 134 could have issues: large positive beta_raw values make cute.math.exp(-beta_raw) underflow to 0, resulting in beta = 1.0 (acceptable), but large negative beta_raw causes exp(-beta_raw) to overflow. Consider adding a symmetric threshold check for the sigmoid, similar to the softplus handling.
💡 Optional: Add numerical safeguard for sigmoid
  g = -cute.math.exp(A_log_val) * softplus_x
  g_exp = cute.math.exp(g)
- beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw))
+ # Numerically stable sigmoid
+ if beta_raw >= cutlass.Float32(0.0):
+     beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw))
+ else:
+     exp_beta = cute.math.exp(beta_raw)
+     beta = exp_beta / (cutlass.Float32(1.0) + exp_beta)
  return g_exp, beta
306-354: Consider removing the unused o_head parameter.
The o_head parameter is passed but never used in process_first_token. The first token's output is returned and stored later during subsequent token processing. Removing it would clarify the function's contract.
♻️ Remove unused parameter
  @cute.jit
  def process_first_token(
      h_sh_chunk_curr,
      h_chunk,
      kq_chunk,
      k_sh,
      q_sh,
      v_sh,
      reduce_sh,
-     o_head,
      g_exp,
      beta,
      v_offset,
      pred_slot,
      warp_idx,
      lane_idx,
      k_base,
  ):
Note: This would require updating all call sites in process_vchunk_unified_234.
1746-1746: Unpacked variable H is never used.
The variable H is extracted from q.shape but never used in the function. Consider prefixing it with an underscore to indicate intentional non-use.
♻️ Prefix unused variable
- B, T, H, K = q.shape
+ B, T, _H, K = q.shape

flashinfer/cute_dsl/benchmark_gated_delta_rule.py (2)
82-82: Import path may fail depending on execution context.
The relative import from gated_delta_rule import gated_delta_rule assumes the script is run from within the flashinfer/cute_dsl/ directory or that that directory is on sys.path. This will fail if run from the repository root or via python -m. Consider using an absolute import for robustness.
♻️ Use absolute import
  def main():
-     from gated_delta_rule import gated_delta_rule
+     from flashinfer.cute_dsl.gated_delta_rule import gated_delta_rule
116-127: Loop variable capture in closure: latent bug risk.
The run_kernel function captures inputs and state by reference, not by value. While this works correctly here because the closure is consumed immediately via benchmark(), it's a latent bug that would manifest if the code were refactored (e.g., collecting closures to run later).
♻️ Bind loop variables explicitly
- def run_kernel():
+ def run_kernel(inputs=inputs, state=state):
      return gated_delta_rule(
          A_log=inputs["A_log"],
          a=inputs["a"],
          dt_bias=inputs["dt_bias"],
          q=inputs["q"],
          k=inputs["k"],
          v=inputs["v"],
          b=inputs["b"],
          initial_state_source=state,
          scale=inputs["scale"],
      )
| # Q/K buffers for token 2 (only for NUM_TOKENS >= 3) | ||
| q_sh2 = smem.allocate_tensor(cutlass.Float32, 128) | ||
| k_sh2 = smem.allocate_tensor(cutlass.Float32, 128) | ||
|
|
||
| # Q/K buffers for token 3 (only for NUM_TOKENS=4) | ||
| q_sh3 = smem.allocate_tensor(cutlass.Float32, 128) | ||
| k_sh3 = smem.allocate_tensor(cutlass.Float32, 128) | ||
|
|
||
| # V buffers | ||
| v_sh0 = smem.allocate_tensor(cutlass.Float32, 128) | ||
| v_sh1 = smem.allocate_tensor(cutlass.Float32, 128) | ||
| v_sh2 = smem.allocate_tensor(cutlass.Float32, 128) | ||
| v_sh3 = smem.allocate_tensor(cutlass.Float32, 128) |
🧩 Analysis chain
🏁 Script executed:
cd /tmp && git ls-files flashinfer/cute_dsl/gated_delta_rule.pyRepository: flashinfer-ai/flashinfer
Length of output: 138
🏁 Script executed:
wc -l flashinfer/cute_dsl/gated_delta_rule.pyRepository: flashinfer-ai/flashinfer
Length of output: 112
🏁 Script executed:
sed -n '1140,1180p' flashinfer/cute_dsl/gated_delta_rule.pyRepository: flashinfer-ai/flashinfer
Length of output: 1718
🏁 Script executed:
# Search for NUM_TOKENS usage and conditional patterns around this section
sed -n '1100,1200p' flashinfer/cute_dsl/gated_delta_rule.py | head -100Repository: flashinfer-ai/flashinfer
Length of output: 3977
🏁 Script executed:
# Look for if statements guarding buffer allocation
rg -n "if.*NUM_TOKENS" flashinfer/cute_dsl/gated_delta_rule.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 357
🌐 Web query:
CuTe DSL Constexpr shared memory allocation dead code elimination
💡 Result:
In CuTe / CuTeDSL (CUTLASS), whether “shared-memory allocation” and the associated copies/loops disappear is mostly a question of whether the compiler can prove things at compile time (i.e., via types and constexpr / if constexpr) and whether the backend can safely apply dead-store / dead-code elimination across the generated code.
1) “constexpr shared memory allocation” in practice (CuTe/CUTLASS)
- CuTe is designed so that tile shapes, layouts, and smem footprints are carried in the type system; when those are compile-time constants, the compiler can fully specialize code (unroll, constant-fold addresses, remove unused objects/paths).
- Practically, that means: prefer CuTe’s compile-time control-flow (
if constexpr,cute::static_for, type-based shapes/layouts) rather than runtime loops/conditions if you expect code (and sometimes even smem objects) to vanish. - CUTLASS’ own guidelines explicitly lean on
if constexpr-based exhaustiveness and “unreachable” patterns to ensure unwanted branches don’t survive compilation (and to avoid “missing return” / unreachable-path issues) [1].
2) Why you can see “dead-looking” stores (local or shared) even at -O3
What you’re describing matches a real pattern people have observed with CuTeDSL-generated kernels: stores to .local (stack) that appear unused even though the values are consumed directly from registers in nearby FFMA instructions.
A recent CuTeDSL example shows exactly this: STL.128 stores to local memory immediately followed by FFMA using the same registers, with no obvious reloads in between [2]. A plausible cause suggested in-thread is conservative aliasing / overlap concerns (compiler can’t prove the store is irrelevant because it can’t fully prove what may alias what) [2].
Also, NVIDIA engineers have long emphasized an important split:
- PTX is intermediate; optimizations (including dead-code elimination) may happen later in
ptxas, and you should judge final behavior by SASS [3]. - However, in your case you’re already looking at SASS and still seeing the stores, which usually means one of:
- the backend is being conservative (aliasing, lifetime issues),
- the store participates in some internal lowering pattern (e.g., preserving values across a region/edge the optimizer doesn’t fully reason about),
- or it’s simply a backend missed-optimization / bug (these have been reported historically for redundant local loads/stores) [4].
3) Concrete knobs that usually decide whether it optimizes away
- Make indices compile-time: if a loop induction variable controls fragment selection/addressing, switching to
cute::static_for/ unrolling often removes unexpected local traffic (your own observation aligns with this). - Prefer
if constexprfor optional smem paths so unused code paths don’t exist in the instantiated kernel (CUTLASS recommends these patterns) [1]. - Avoid debug/per-function compilation modes when judging DCE (e.g.,
-G, device debug) because they commonly inhibit optimization (this is a common explanation for “dead” locals in PTX/SASS discussions) [3]. - If aliasing is plausible, try enforcing non-aliasing where applicable (CUDA C++:
__restrict__on raw pointers; with CuTe views it’s trickier, but the same principle applies). The forum reply explicitly raises aliasing as a reason the compiler may keep the stores [2].
If you want, paste a minimal CuTeDSL repro (kernel + compile flags + GPU arch + CUDA version) and I’ll point out exactly which parts prevent compile-time proving (and what to change so smem/lmem traffic is actually eliminated).
Sources: CUTLASS programming guidelines [1]; NVIDIA forum thread showing CuTeDSL dead-looking STL.128 in SASS and aliasing hypothesis [2]; NVIDIA forum explanation that PTX is intermediate and DCE happens in ptxas / check SASS and debug builds inhibit this [3]; historical report of redundant lmem loads/stores as compiler/backend artifact/bug [4].
Unconditional SMEM allocation for unused token buffers wastes memory for smaller batch sizes.
The Q/K/V shared memory buffers are allocated unconditionally at compile time for all four tokens, even when NUM_TOKENS=2 or 3. Comments indicate buffers are "only for NUM_TOKENS >= 3" and "only for NUM_TOKENS=4", but no compile-time conditionals guard these allocations. The subsequent gate computations at lines 1195+ use if NUM_TOKENS runtime checks, which cannot eliminate the earlier allocations. For NUM_TOKENS=2, this means allocating q_sh2, k_sh2, q_sh3, k_sh3, and partial v_sh buffers unnecessarily.
While this trade-off is acceptable for a unified kernel (avoiding code duplication and supporting multiple sequence lengths), the SMEM footprint could be reduced by wrapping allocations in if constexpr guards tied to the NUM_TOKENS parameter, allowing true compile-time elimination rather than runtime conditionals.
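A sketch of the guarded allocations this suggests; it assumes NUM_TOKENS is the kernel's compile-time (cutlass.Constexpr) parameter and that cutlass.const_expr is available as the compile-time predicate in the pinned CuTe-DSL version, so treat it as a pattern rather than verified code.

```python
# Allocate the per-token Q/K/V buffers only when the compile-time NUM_TOKENS
# needs them; the later uses would need the same compile-time guards so the
# names are only referenced in specializations where they actually exist.
if cutlass.const_expr(NUM_TOKENS >= 3):
    # Q/K buffers for token 2
    q_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
if cutlass.const_expr(NUM_TOKENS == 4):
    # Q/K buffers for token 3
    q_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh3 = smem.allocate_tensor(cutlass.Float32, 128)

# V buffers: tokens 0/1 are always needed, 2 and 3 only for larger T
v_sh0 = smem.allocate_tensor(cutlass.Float32, 128)
v_sh1 = smem.allocate_tensor(cutlass.Float32, 128)
if cutlass.const_expr(NUM_TOKENS >= 3):
    v_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
if cutlass.const_expr(NUM_TOKENS == 4):
    v_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
```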
| def __init__(self, seq_len: int): | ||
| assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}" | ||
| self.seq_len = seq_len | ||
| self._compiled_kernel = None | ||
|
|
||
| def _get_launch_fn(self): | ||
| if self.seq_len == 1: | ||
| return gated_delta_rule_launch_seqlen1 | ||
| elif self.seq_len == 2: | ||
| return gated_delta_rule_launch_seqlen2 | ||
| elif self.seq_len == 3: | ||
| return gated_delta_rule_launch_seqlen3 | ||
| else: | ||
| return gated_delta_rule_launch_seqlen4 | ||
|
|
🧩 Analysis chain
🏁 Script executed:
# Check if GatedDeltaRuleKernel is used anywhere in the codebase
rg -n 'GatedDeltaRuleKernel' --type=py -A 2 -B 2Repository: flashinfer-ai/flashinfer
Length of output: 900
🏁 Script executed:
# Read the full GatedDeltaRuleKernel class implementation
cat -n flashinfer/cute_dsl/gated_delta_rule.py | sed -n '1645,1750p'Repository: flashinfer-ai/flashinfer
Length of output: 4690
🏁 Script executed:
# Check for the gated_delta_rule function and its implementation
rg -n 'def gated_delta_rule' --type=py -A 30Repository: flashinfer-ai/flashinfer
Length of output: 21951
🏁 Script executed:
# Check if gated_delta_rule function uses GatedDeltaRuleKernel internally
sed -n '1689,1800p' flashinfer/cute_dsl/gated_delta_rule.py | grep -i 'GatedDeltaRuleKernel\|_compiled_kernel'Repository: flashinfer-ai/flashinfer
Length of output: 192
🏁 Script executed:
# Search for any usage of _compiled_kernel attribute across the file
rg -n '_compiled_kernel' flashinfer/cute_dsl/gated_delta_rule.pyRepository: flashinfer-ai/flashinfer
Length of output: 366
GatedDeltaRuleKernel class is incomplete and unused.
The class initializes _compiled_kernel = None but never uses it, has no execution method (no __call__, execute, or similar), and is not instantiated by the public gated_delta_rule function. The module uses a module-level _compiled_kernels cache instead. Consider either completing this class with an execution interface or removing it from the public API if it's not intended for direct use.
| B, T, H, K = q.shape | ||
| assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}" | ||
| HV = v.shape[2] | ||
| V = v.shape[3] |
Missing input validation for required parameters.
The function signature declares q, k, v, b, and initial_state_source as Optional, but the code unconditionally accesses them (e.g., q.shape, v.shape[2], from_dlpack(initial_state_source, ...)). If any are None, this will raise unhelpful errors.
Either validate these inputs explicitly or change the type hints to reflect they're actually required.
🛡️ Proposed fix: Add validation
global _compiled_kernels
+ if q is None or k is None or v is None or b is None or initial_state_source is None:
+ raise ValueError("q, k, v, b, and initial_state_source are required parameters")
+
B, T, H, K = q.shape
assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}"

Or change the signatures to non-Optional:
- q: Optional[torch.Tensor] = None,
- k: Optional[torch.Tensor] = None,
- v: Optional[torch.Tensor] = None,
- b: Optional[torch.Tensor] = None,
- initial_state_source: Optional[torch.Tensor] = None,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ b: torch.Tensor,
+ initial_state_source: torch.Tensor,

🧰 Tools
🪛 Ruff (0.14.14)
[warning] 1746-1746: Unpacked variable H is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
| # Check cache | ||
| cache_key = (T, B) |
Cache key may cause unnecessary recompilations or cache misses.
The cache key (T, B) doesn't account for tensor shapes beyond batch size and sequence length. If H, HV, K, or V change between calls with the same (T, B), the cached kernel might be invoked with incompatible tensor shapes, potentially causing silent correctness issues or crashes.
Consider including relevant shape dimensions in the cache key.
🐛 Proposed fix
# Check cache
- cache_key = (T, B)
+ cache_key = (T, B, H, HV, K, V)
  if cache_key not in _compiled_kernels:
|
@ameynaik-hub would you mind comparing with https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py ? It's better to unify the file location and interface. |
|
@yzh119 compared the benchmarks here: #2493 (comment). The https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_decode.py kernel supports more cases (e.g., T>4 and the v-last layout); this kernel, on the other hand, only supports V-major (K contiguous, k-last dim). Any suggestions on how to merge?
and the tests as well
I think it's worth copying the results here.
|
The new kernel is always better than the old one with pre-transpose, except at batch=1, which we can excuse. Does it make sense to keep the old pre-transpose variant?
|
@ameynaik-hub I might have overlooked details during reading but does anyone have a source for this:
|
@vadiklyutiy since it has support for T>4 (seqlen) and also supports the v-last layout of the state.
As I remember, there were several different interfaces/functions. I meant the one that processes the seq_len=1, k-last case.
|
@ameynaik-hub can you reuse the gdn decode interface and set your implementation as the default when we use the k-last (V-major) layout and T <= 4?
@ameynaik-hub this LGTM as the initial PR to just get the kernel into the Flashinfer codebase, approved. (still need a codeowner approval for this directory to merge ... maybe @yzh119 or @bkryu ). We can follow up with another PR for some work for the API integration + Flashinfer testing/benchmarking. Let me know if you have the bandwidth to work on this integration; if not, I can probably work on it next week.
|
Hi @kahyunnam, I'd encourage not pushing more code here. The required work to unify the existing gdn decode and this PR should be minimal; I can help if you need any assistance with it.
|
|
||
|
|
||
| def benchmark( | ||
| func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True |
We already have benchmarking APIs https://docs.flashinfer.ai/generated/flashinfer.testing.bench_gpu_time.html#flashinfer.testing.bench_gpu_time, please refer to https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_gdn_decode.py#L1088-L1095 on how to use these APIs.
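For reference, a rough sketch of how the standalone script could sit on top of that helper; only the positional callable is assumed below (warmup and L2-flush knobs exist per the linked docs), the returned measurements are treated as per-iteration times in milliseconds, and the problem sizes come from the docstring example in gated_delta_rule.py.

```python
import torch
import numpy as np
from flashinfer.testing import bench_gpu_time
from flashinfer.cute_dsl.gated_delta_rule import gated_delta_rule

# Problem sizes taken from the docstring example in gated_delta_rule.py
B, T, H, K = 16, 1, 16, 128
HV, V = 32, 128
q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, HV, V, device="cuda", dtype=torch.bfloat16)
a = torch.randn(B, T, HV, device="cuda", dtype=torch.bfloat16)
b = torch.randn(B, T, HV, device="cuda", dtype=torch.bfloat16)
A_log = torch.randn(HV, device="cuda", dtype=torch.float32)
dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32)
h_state = torch.randn(B, HV, V, K, device="cuda", dtype=torch.bfloat16)

# bench_gpu_time handles warmup / timing; extra keyword knobs (L2 flush,
# iteration counts) exist per the linked docs but only the callable is assumed here.
measurements = bench_gpu_time(
    lambda: gated_delta_rule(
        A_log, a, dt_bias, q=q, k=k, v=v, b=b, initial_state_source=h_state
    )
)
print(f"median kernel time: {np.median(measurements):.3f} ms")
```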
I'm just curious why that wasn't addressed and pointed out in #2370 but was here?
| else: # T == 4 | ||
| launch_fn = gated_delta_rule_launch_seqlen4 | ||
|
|
||
| _compiled_kernels[cache_key] = cute.compile( |
Please enable tvm-ffi with options="--enable-tvm-ffi", otherwise the host-side overhead of converting torch.Tensor to dlpack would be non-negligible.
Reference can be found at #2279
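Concretely, something along these lines at the compile call site; the flag string comes from this comment / PR #2279, and existing_compile_args is a placeholder for whatever the current cute.compile call already passes (dlpack tensors, scalars, stream).

```python
# Placeholder sketch only: same call site as today, with the tvm-ffi option added.
_compiled_kernels[cache_key] = cute.compile(
    launch_fn,
    *existing_compile_args,  # placeholder for the args the current call already passes
    options="--enable-tvm-ffi",
)
```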
@yzh119 yeah I can do that. I noticed that my kernel utilizes bf16 inputs for the state |
|
cc @guangyunh-nv (gdn prefill's author) do you have any insights on whether it's okay to use f16 state? |
|
FP16 state might not be a good choice due to its dynamic range limitation. The range is not guaranteed and purely depends on the model activation dynamics, so this really depends on the model behavior. Data point: BF16 may be safer, but then accuracy would be a concern. If the algorithm side decides that this is tolerable, then I think we should support it. I think what we can do at the moment is to tolerate all kinds of state input and learn from production environment feedback :)
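For reference, the dynamic-range vs. precision trade-off being discussed is visible directly from the dtype metadata (generic PyTorch, not part of this PR):

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    # fp16: ~11-bit significand but overflows past ~65504;
    # bf16: fp32's exponent range (~3.4e38) with only an ~8-bit significand.
    print(f"{str(dtype):>15}: max={info.max:.3e}  eps={info.eps:.3e}")
```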
|
Have you tested the end-to-end accuracy with FP16 SSM state? Does the performance remain unchanged on common benchmarks such as MMLU, GSM8K, etc.? |
I don't think this is the right approach. Imagine a car manufacturer introducing a new self-driving car and saying "let's learn from the road environment." I don't think a lot of people would be happy.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 1695-1709: The parameter use_qk_l2norm_in_kernel on
gated_delta_rule is currently ignored because the kernel always normalizes Q/K;
add an explicit guard at the top of gated_delta_rule to either (a) implement a
non-normalizing code path for Q/K or (b) fail fast — e.g., if not
use_qk_l2norm_in_kernel: raise ValueError("use_qk_l2norm_in_kernel=False is not
supported: kernel currently always L2-normalizes q/k"), so callers get explicit
behavior instead of silent incorrect semantics; reference the gated_delta_rule
function and the use_qk_l2norm_in_kernel parameter when adding the check.
- Around line 1695-1765: The gated_delta_rule kernel assumes specific dtypes and
shape invariants (BF16 for q/k/v/a/b and float32 for A_log/dt_bias, K==128, HV
divisible by H, initial_state_source shape [B, HV, V, K] and matching K/V/H/B
dimensions) but doesn't validate them; add upfront assertions in
gated_delta_rule to check q/k/v/a/b dtypes, A_log/dt_bias dtypes, that
q.shape==(B,T,H,K) and v.shape==(B,T,HV,V), that K == 128 (or require K ==
k.shape[-1] and optionally enforce 128), that HV % H == 0, and that
initial_state_source (if provided) has dtype and shape [B, HV, V, K] and
matching B, HV, V, K; raise clear ValueError/AssertionError with messages
referencing the symbols (q, k, v, a, b, A_log, dt_bias, initial_state_source,
HV, K, H) so callers fail fast instead of producing silent corruption.
🧹 Nitpick comments (4)
flashinfer/cute_dsl/gated_delta_rule.py (1)
311-360: Remove the unused o_head parameter to avoid lint noise.
It is never referenced inside process_first_token.
♻️ Suggested fix
  def process_first_token(
      h_sh_chunk_curr,
      h_chunk,
      kq_chunk,
      k_sh,
      q_sh,
      v_sh,
      reduce_sh,
-     o_head,
+     _o_head,
      g_exp,
      beta,
      v_offset,
      pred_slot,
      warp_idx,
      lane_idx,
      k_base,
  ):

tests/gdn/test_decode_delta_rule.py (1)
622-653: Mark unused alpha as intentionally unused to silence lint.
This keeps the signature but avoids ARG001 noise.
♻️ Suggested fix
  def _test_improved_cutedsl_kernel(
      dtype: str,
      batch_size: int,
      num_q_heads: int,
      num_k_heads: int,
      num_v_heads: int,
      head_size: int,
      seq_len: int,  # T=1,2,3,4
      scale: float,
-     alpha: bool,
+     _alpha: bool,
      beta: bool,
      seed: int | None = None,
  ):

benchmarks/bench_gdn_decode.py (2)
1835-1875: Avoid the unused output parameter in the wrapper.
Rename it to _output to silence ARG001 without changing behavior.
♻️ Suggested fix
  def improved_cutedsl_gdn_wrapper(
      @@
-     output: torch.Tensor,  # [B, T, HV, V] - unused, kernel returns output directly
+     _output: torch.Tensor,  # [B, T, HV, V] - unused, kernel returns output directly
      use_qk_l2norm: bool = True,
      softplus_beta: float = 1.0,
      softplus_threshold: float = 20.0,
  ):
2213-2292: Use float32 dt_bias in the improved CuTe-DSL benchmark.
The improved kernel path (and tests) treat dt_bias as float32; aligning the benchmark avoids unintended precision differences.
♻️ Suggested fix
- dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+ dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")
| def gated_delta_rule( | ||
| A_log: torch.Tensor, | ||
| a: torch.Tensor, | ||
| dt_bias: torch.Tensor, | ||
| softplus_beta: float = 1.0, | ||
| softplus_threshold: float = 20.0, | ||
| q: Optional[torch.Tensor] = None, | ||
| k: Optional[torch.Tensor] = None, | ||
| v: Optional[torch.Tensor] = None, | ||
| b: Optional[torch.Tensor] = None, | ||
| initial_state_source: Optional[torch.Tensor] = None, | ||
| initial_state_indices: Optional[torch.Tensor] = None, | ||
| use_qk_l2norm_in_kernel: bool = True, | ||
| scale: Optional[float] = None, | ||
| ) -> torch.Tensor: |
use_qk_l2norm_in_kernel is ignored — make behavior explicit.
The kernel always normalizes Q/K, so passing False silently produces incorrect semantics.
🛡️ Suggested guard
A_log: torch.Tensor,
a: torch.Tensor,
dt_bias: torch.Tensor,
softplus_beta: float = 1.0,
softplus_threshold: float = 20.0,
@@
use_qk_l2norm_in_kernel: bool = True,
scale: Optional[float] = None,
) -> torch.Tensor:
@@
global _compiled_kernels
+
+ if not use_qk_l2norm_in_kernel:
+ raise NotImplementedError(
+ "CuTe-DSL GDN kernel currently always applies Q/K L2 normalization"
+     )

🧰 Tools
🪛 Ruff (0.14.14)
[warning] 1706-1706: Unused function argument: initial_state_indices
(ARG001)
[warning] 1707-1707: Unused function argument: use_qk_l2norm_in_kernel
(ARG001)
| def gated_delta_rule( | ||
| A_log: torch.Tensor, | ||
| a: torch.Tensor, | ||
| dt_bias: torch.Tensor, | ||
| softplus_beta: float = 1.0, | ||
| softplus_threshold: float = 20.0, | ||
| q: Optional[torch.Tensor] = None, | ||
| k: Optional[torch.Tensor] = None, | ||
| v: Optional[torch.Tensor] = None, | ||
| b: Optional[torch.Tensor] = None, | ||
| initial_state_source: Optional[torch.Tensor] = None, | ||
| initial_state_indices: Optional[torch.Tensor] = None, | ||
| use_qk_l2norm_in_kernel: bool = True, | ||
| scale: Optional[float] = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Gated Delta Rule linear attention kernel. | ||
|
|
||
| Implements the Gated Delta Rule mechanism for decode-phase inference, | ||
| supporting sequence lengths T=1, T=2, T=3, T=4. | ||
|
|
||
| Args: | ||
| A_log: Log decay parameter [HV] | ||
| a: Alpha gate input [B, T, HV] | ||
| dt_bias: Delta-t bias [HV] | ||
| softplus_beta: Softplus beta parameter (default: 1.0) | ||
| softplus_threshold: Softplus threshold (default: 20.0) | ||
| q: Query tensor [B, T, H, K] | ||
| k: Key tensor [B, T, H, K] | ||
| v: Value tensor [B, T, HV, V] | ||
| b: Beta gate input [B, T, HV] | ||
| initial_state_source: H state [B, HV, V, K] (K-fast layout), modified in-place | ||
| initial_state_indices: Not used (for compatibility) | ||
| use_qk_l2norm_in_kernel: Whether to L2-normalize Q/K in kernel (default: True) | ||
| scale: Optional attention scale (default: 1/sqrt(K)) | ||
|
|
||
| Returns: | ||
| output: [B, T, HV, V] | ||
|
|
||
| Example: | ||
| >>> B, T, H, K = 16, 1, 16, 128 | ||
| >>> HV, V = 32, 128 | ||
| >>> q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) | ||
| >>> k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) | ||
| >>> v = torch.randn(B, T, HV, V, device='cuda', dtype=torch.bfloat16) | ||
| >>> a = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16) | ||
| >>> b = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16) | ||
| >>> A_log = torch.randn(HV, device='cuda', dtype=torch.float32) | ||
| >>> dt_bias = torch.randn(HV, device='cuda', dtype=torch.float32) | ||
| >>> h_state = torch.randn(B, HV, V, K, device='cuda', dtype=torch.bfloat16) | ||
| >>> output = gated_delta_rule( | ||
| ... A_log, a, dt_bias, q=q, k=k, v=v, b=b, | ||
| ... initial_state_source=h_state | ||
| ... ) | ||
| """ | ||
| global _compiled_kernels | ||
|
|
||
| B, T, H, K = q.shape | ||
| assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}" | ||
| HV = v.shape[2] | ||
| V = v.shape[3] | ||
|
|
||
| if scale is None: | ||
| scale = 1.0 / math.sqrt(K) | ||
|
|
||
| output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) | ||
|
|
||
| q_ = from_dlpack(q, assumed_align=32) | ||
| k_ = from_dlpack(k, assumed_align=32) | ||
| v_ = from_dlpack(v, assumed_align=32) | ||
| a_ = from_dlpack(a, assumed_align=32) |
Add explicit dtype/shape validation for kernel invariants.
The kernel stores outputs and state as BF16 and assumes K/V=128 plus HV divisible by H. Without guards, callers can silently get corrupted outputs.
🛡️ Suggested validation
global _compiled_kernels
B, T, H, K = q.shape
assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}"
HV = v.shape[2]
V = v.shape[3]
+
+ if K != 128 or V != 128:
+ raise ValueError(f"CuTe-DSL GDN kernel expects K=V=128, got K={K}, V={V}")
+ if HV % H != 0:
+ raise ValueError(f"HV must be divisible by H (HV={HV}, H={H})")
+ if q.dtype != torch.bfloat16 or k.dtype != torch.bfloat16 or v.dtype != torch.bfloat16:
+ raise ValueError("CuTe-DSL GDN kernel expects q/k/v in torch.bfloat16")
+ if initial_state_source.dtype != torch.bfloat16:
+ raise ValueError("CuTe-DSL GDN kernel expects state in torch.bfloat16")
+ if A_log.dtype != torch.float32 or dt_bias.dtype != torch.float32:
+ raise ValueError("CuTe-DSL GDN kernel expects A_log/dt_bias in torch.float32")
+ if initial_state_source.shape != (B, HV, V, K):
+ raise ValueError(
+ f"Expected state shape [B, HV, V, K] = ({B}, {HV}, {V}, {K}), "
+ f"got {tuple(initial_state_source.shape)}"
+     )

🧰 Tools
🪛 Ruff (0.14.14)
[warning] 1706-1706: Unused function argument: initial_state_indices
(ARG001)
[warning] 1707-1707: Unused function argument: use_qk_l2norm_in_kernel
(ARG001)
[warning] 1752-1752: Unpacked variable H is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
|
@ameynaik-hub Thanks for linking to the PR by @yzh119:
Could the maintainers clarify why the N,H,V,K layout is preferred? Most GDN reference implementations (e.g., the FLA project) expect the recurrent state in N,H,K,V, and the reference used in the tests in PR #2276 follows the same convention; there we explicitly transpose the kernel output from N,H,V,K before comparing against the reference. PR #2370 already provides both a pre-transpose (N,H,V,K) and a non-transpose (N,H,K,V) path, and the inline comments recommend the non-transpose version. I've searched the Qwen3 documentation and the official Hugging Face Transformers implementation (which uses FLA under the hood) and haven't found any explicit requirement for the N,H,V,K layout. Is there a downstream consumer or hardware target that mandates this layout? |
@yzh119 convinced me in Slack that K-last is better. vLLM has already moved to the K-last layout, and this by itself gives a perf improvement :) |
|
Generally speaking, if the state is K-last (K contiguous), you can vectorize the reads along the K dimension, and the inner products with k and q can be performed in parallel across rows more easily. |
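To make this concrete, here is a small standalone PyTorch sketch (illustration only, not part of the PR) showing that the K-last state layout [B, HV, V, K] puts K at stride 1, so each row of H read along K is contiguous and easy to vectorize:

```python
import torch

B, HV, K, V = 1, 32, 128, 128

# K-last (K contiguous): [B, HV, V, K]. A row H[b, hv, v, :] is a contiguous
# 128-element vector, so per-row dot products with q/k vectorize cleanly.
h_klast = torch.empty(B, HV, V, K, dtype=torch.bfloat16)
print(h_klast.stride(-1))  # 1 -> elements along K are adjacent in memory

# K-second-to-last alternative: [B, HV, K, V]. Elements along K are now V
# apart, so reading a K-row requires strided (harder-to-coalesce) accesses.
h_vlast = torch.empty(B, HV, K, V, dtype=torch.bfloat16)
print(h_vlast.stride(2))   # 128 -> elements along K are V elements apart
```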
|
@vadiklyutiy The reason for K-last is as follows: in prefill, you need to repeatedly update the state in the kernel mainloop, so you'd better keep the state in registers (Hopper). This means you put S in the A operand, which avoids repeatedly transposing and updating the state in B in an SMEM-style kernel. It is very critical to kernel performance. I believe Blackwell has the same constraint due to the TMEM access pattern. |
…scing and bank conflicts Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 312-359: Remove the dead o_head parameter from the
process_first_token function signature and all call sites; update the function
signature for process_first_token to omit o_head and adjust every invocation
(notably in process_vchunk_unified_234) to stop passing that argument, ensuring
argument order and remaining parameters (h_sh_chunk_curr, h_chunk, kq_chunk,
k_sh, q_sh, v_sh, reduce_sh, g_exp, beta, v_offset, pred_slot, warp_idx,
lane_idx, k_base) remain unchanged; rebuild and run tests to confirm no other
call sites reference o_head.
🧹 Nitpick comments (2)
flashinfer/cute_dsl/gated_delta_rule.py (2)
696-1097: Seqlen=1 kernel: 4 chunks are manually unrolled with identical logic. Chunks 0–3 (lines 807–1096) repeat the same pattern: wait for async load → decay H from SMEM → compute pred → cross-warp reduce → delta update → compute output → reduce → write-back. This is ~290 lines of near-identical code with only the chunk index and V-offset changing. Consider factoring the per-chunk logic into a helper (similar to process_vchunk_unified_234), passing the chunk-specific SMEM buffer and offset as parameters. The persistent-K optimization and pipelining schedule can still be preserved.
2009-2026: --generate-line-info left in production compile options. This flag adds debug line info to the compiled kernel binary, increasing code size and potentially impacting instruction cache utilization. Consider making it conditional (e.g., controlled by an env var or debug flag, as in the sketch below) or removing it for production builds.
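A minimal sketch of the "make it conditional" suggestion; the environment-variable name and helper below are illustrative assumptions, not code from this PR:

```python
import os


def build_compile_options(base_options: list[str]) -> list[str]:
    """Hypothetical helper: append debug-only flags to the kernel compile options."""
    opts = list(base_options)
    # Only emit line info when explicitly requested, e.g. for profiling sessions.
    if os.environ.get("FLASHINFER_CUTE_DSL_LINEINFO", "0") == "1":
        opts.append("--generate-line-info")
    return opts


# Example: production builds get the base options only.
print(build_compile_options(["-O3"]))
```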
def process_first_token(
    h_sh_chunk_curr,
    h_chunk,
    kq_chunk,
    k_sh,
    q_sh,
    v_sh,
    reduce_sh,
    o_head,
    g_exp,
    beta,
    v_offset,
    pred_slot,
    warp_idx,
    lane_idx,
    k_base,
):
    """
    Process the first token in a V-chunk (T=0).
    - Load K from SMEM
    - Decay H from SMEM and compute pred
    - Cross-warp reduce pred (uses pred_slot)
    - Update H with delta
    - Load Q and compute output
    Returns: out (partial output, not yet reduced)
    """
    # Load K for this token
    load_kq_chunk_from_smem(k_sh, kq_chunk, k_base)

    # Decay H from SMEM and compute pred = H * K
    pred = decay_h_from_smem_and_compute_pred(
        h_sh_chunk_curr, h_chunk, kq_chunk, g_exp, lane_idx, k_base
    )

    # Reduce pred across warps (slot 0 for first token)
    pred_final = cross_warp_reduce_single(
        reduce_sh, pred_slot, warp_idx, lane_idx, pred
    )

    # Compute delta and update H
    v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta
    update_h_with_delta(h_chunk, kq_chunk, v_delta)

    # Load Q and compute output
    load_kq_chunk_from_smem(q_sh, kq_chunk, k_base)
    out = compute_output(h_chunk, kq_chunk)

    return out
Unused o_head parameter in process_first_token.
The o_head parameter (line 320) is never referenced in the function body. The first token's output is returned as out and stored by the next token's processing function. This dead parameter adds confusion and wastes a register.
Proposed fix
`@cute.jit`
def process_first_token(
h_sh_chunk_curr,
h_chunk,
kq_chunk,
k_sh,
q_sh,
v_sh,
reduce_sh,
- o_head,
g_exp,
beta,
v_offset,
pred_slot,
warp_idx,
lane_idx,
k_base,
):

And update all call sites accordingly (in process_vchunk_unified_234).
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 320-320: Unused function argument: o_head
(ARG001)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/gated_delta_rule.py` around lines 312 - 359, Remove the
dead o_head parameter from the process_first_token function signature and all
call sites; update the function signature for process_first_token to omit o_head
and adjust every invocation (notably in process_vchunk_unified_234) to stop
passing that argument, ensuring argument order and remaining parameters
(h_sh_chunk_curr, h_chunk, kq_chunk, k_sh, q_sh, v_sh, reduce_sh, g_exp, beta,
v_offset, pred_slot, warp_idx, lane_idx, k_base) remain unchanged; rebuild and
run tests to confirm no other call sites reference o_head.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/gated_delta_rule.py`:
- Around line 1179-1182: The comment is incorrect: reduce_sh allocated via
smem.allocate_tensor(cutlass.Float32, cute.make_layout((8, 32, 4), stride=(128,
4, 1))) yields 4-way bank conflicts for reduce_sh[slot, lane_idx, warp_idx];
either update the comment to state it’s 4-way conflicted or change the layout to
a bank-conflict-free interleaving (e.g., replace the stride tuple with an
interleaved BCF stride such as using 33-element pitch for the lane dimension
like stride=(33*4, 1, 33) or equivalent) so accesses across lane_idx map to
unique banks; modify the smem.allocate_tensor call for reduce_sh accordingly and
keep the symbol names reduce_sh, smem.allocate_tensor, cute.make_layout and
cutlass.Float32 to locate the change.
🧹 Nitpick comments (2)
flashinfer/cute_dsl/gated_delta_rule.py (2)
696-1097: Seqlen=1 kernel: 4 nearly-identical chunk blocks (~300 lines of repeated code). Chunks 0–3 (lines 807–1096) repeat the same decay → pred → reduce → delta → output → H-writeback pattern with only the SMEM buffer name and V offset changing. While inlining is understandable for pipelining (the async load/store interleaving differs per chunk), the core compute block could be extracted into a helper that takes the chunk's H SMEM buffer and V offset, similar to how process_vchunk_unified_234 factored the seqlen≥2 path. This would cut ~200 lines with no runtime cost, since CuTe-DSL inlines @cute.jit helpers at compile time.
1998-2007: Low-BS dispatch threshold (B <= 4) is hard-coded without documented rationale. The T=1 path selects the low-BS kernel (4 CTAs per batch×head) when B <= 4. This threshold likely comes from empirical benchmarking, but there is no comment or constant explaining the crossover point. If the benchmark characteristics change (e.g., different GPU, different HV), this threshold could become stale. Consider extracting 4 into a named constant (e.g., _LOWBS_THRESHOLD = 4) with a brief comment about the rationale; see the sketch below.
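A minimal sketch of the named-constant suggestion (the constant, helper name, and comment wording are illustrative, not taken from the PR):

```python
# Empirically chosen crossover: below this batch size the low-BS kernel
# (4 CTAs per batch x head) outperforms the regular T=1 kernel.
_LOWBS_THRESHOLD = 4


def _use_lowbs_kernel(batch_size: int, seq_len: int) -> bool:
    # Hypothetical helper mirroring the existing T=1 dispatch condition.
    return seq_len == 1 and batch_size <= _LOWBS_THRESHOLD
```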
# Bank-conflict-free reduce_sh: [slot, lane_idx, warp_idx]
reduce_sh = smem.allocate_tensor(
    cutlass.Float32, cute.make_layout((8, 32, 4), stride=(128, 4, 1))
)
reduce_sh layout has 4-way bank conflicts despite the "bank-conflict-free" comment.
With stride (128, 4, 1), reduce_sh[slot, lane_idx, warp_idx] maps to offset slot*128 + lane_idx*4 + warp_idx. Within a single warp reading a fixed warp_idx, the bank is (lane_idx*4 + warp_idx) % 32, which means lanes 0, 8, 16, 24 (and similar groups) map to the same bank — yielding 4-way conflicts on both writes and reads.
Compare with the seqlen=1 kernel's pred_sh/out_sh at lines 760–764, which uses stride (1, 32) — truly bank-conflict-free since lane_idx + warp_idx*32 gives each lane a unique bank.
Consider either updating the comment to reflect the actual conflict level, or switching to a BCF layout (e.g., (8, 32, 4) with stride (33*4, 1, 33) or similar interleaving).
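The conflict count is easy to verify numerically with a standalone sketch, assuming the usual 32 shared-memory banks of 4-byte words (this is independent of the kernel code):

```python
from collections import Counter


def max_bank_conflict(stride, slot=0, warp=0, lanes=32, banks=32):
    """Worst-case conflict degree for one warp reading reduce_sh[slot, lane, warp]."""
    # Each Float32 element occupies exactly one 4-byte bank word.
    offsets = [slot * stride[0] + lane * stride[1] + warp * stride[2] for lane in range(lanes)]
    return max(Counter(off % banks for off in offsets).values())


print(max_bank_conflict((128, 4, 1)))      # 4 -> the 4-way conflict described above
print(max_bank_conflict((33 * 4, 1, 33)))  # 1 -> lane dimension maps to distinct banks
```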
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/gated_delta_rule.py` around lines 1179 - 1182, The
comment is incorrect: reduce_sh allocated via
smem.allocate_tensor(cutlass.Float32, cute.make_layout((8, 32, 4), stride=(128,
4, 1))) yields 4-way bank conflicts for reduce_sh[slot, lane_idx, warp_idx];
either update the comment to state it’s 4-way conflicted or change the layout to
a bank-conflict-free interleaving (e.g., replace the stride tuple with an
interleaved BCF stride such as using 33-element pitch for the lane dimension
like stride=(33*4, 1, 33) or equivalent) so accesses across lane_idx map to
unique banks; modify the smem.allocate_tensor call for reduce_sh accordingly and
keep the symbol names reduce_sh, smem.allocate_tensor, cute.make_layout and
cutlass.Float32 to locate the change.
…mproved_cutedsl - gdn_decode.py: optional backend for pretranspose when bf16 state, T<=4, K=V=128; dispatch to cute_dsl gated_delta_rule. - bench_gdn_decode.py: rename improved_cutedsl to gdn_decode_klast_bf16_state (--version, wrapper, result keys). - test_decode_delta_rule.py: same rename; add test_pretranspose_api_uses_gdn_decode_klast_bf16_state for API dispatch. Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
|
cute dsl bf16
|
|
/bot run |
|
[FAILED] Pipeline #43822304: 11/20 passed |
- Bank-conflict-free cross-warp reductions
- Async H memory loading with aggressive pipelining
- BF16 tensors with FP32 compute for numerical stability
- GQA (grouped-query attention) support with configurable H (query) and HV (value) heads
GQA: HQ = <IntegerRatio> * HV
GVA: <IntegerRatio> * HQ = HV. I confirmed this naming in a group chat with the GDN paper author :)
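For concreteness, a tiny sketch of the two regimes (the helper name is made up); with this kernel's typical shapes, H=16 query heads and HV=32 value/state heads, the ratio goes the GVA way:

```python
def head_regime(hq: int, hv: int) -> str:
    # GQA: query heads are a multiple of value/state heads (HQ = r * HV).
    # GVA: value/state heads are a multiple of query heads (HV = r * HQ).
    if hq % hv == 0:
        return "GQA"
    if hv % hq == 0:
        return "GVA"
    return "neither"


print(head_regime(16, 32))  # "GVA": HV = 2 * HQ, matching H=16, HV=32 above
```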
@guangyunh-nv sorry, do you mean it is called GVA and not GQA?
|
Thanks to everyone for clarifying, this helped :) @vadiklyutiy @guangyunh-nv @ameynaik-hub |
|
Out of curiosity, why does the perf degrade when the batch goes from 1 to 4?
I'd assume you have not yet hit the parallelism or bandwidth limit, so the time should stay roughly constant. |
@guangyunh-nv for BS <= 4 I use 4 CTAs per head, so for BS=1 that is 32*4 = 128 CTAs and for BS=4 it is 32*4*4 = 512 CTAs |
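Restating that grid-size arithmetic as a quick sketch (HV=32 value heads assumed from the shapes used earlier in this thread):

```python
HV = 32            # value/state heads (assumed from the examples in this thread)
CTAS_PER_HEAD = 4  # low-BS path launches 4 CTAs per (batch, head) pair

for batch_size in (1, 4):
    print(batch_size, batch_size * HV * CTAS_PER_HEAD)  # BS=1 -> 128 CTAs, BS=4 -> 512 CTAs
```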
tests/gdn/reference_delta_rule.py
Outdated
@@ -137,6 +137,7 @@ def blockwise_linear_attention(
    | torch.Tensor = 1.0,  # float or tensor with num_elems == num_qo_heads
    decay_exponent_offset=0,
    kv_dtype: torch.dtype = torch.float32,
I think you should remove kv_dtype; it is named kv_dtype because some earlier papers call the state itself KV. It is not the dtype of K and V.
done, can you please confirm if the change is okay? Thanks!
- Consolidate to single state_dtype parameter across all reference functions
- Remove duplicate kv_dtype parameter from blockwise_linear_attention(),
delta_rule(), and blockwise_delta_rule()
- Update test_prefill_delta_rule.py to use state_dtype consistently
- Remove benchmark_gated_delta_rule.py from git tracking (keep locally)
- Add to .gitignore for local development use only
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
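A rough illustration of the consolidation described in the commit message above (signature details are assumptions; only the state_dtype parameter and the removal of kv_dtype are taken from the actual change):

```python
import torch


def blockwise_linear_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    decay_exponent_offset: int = 0,
    state_dtype: torch.dtype = torch.float32,  # single knob; replaces the old kv_dtype
):
    # Reference implementation body omitted; only the parameter surface is sketched here.
    ...
```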
yzh119
left a comment
Overall LGTM, some minor nits.
q: torch.Tensor,      # [B, T, H_Q, K] where T=1,2,3,4
k: torch.Tensor,      # [B, T, H_K, K]
v: torch.Tensor,      # [B, T, HV, V]
state: torch.Tensor,  # [B, HV, V, K] - K-fast layout (pretranspose)
-    state: torch.Tensor,  # [B, HV, V, K] - K-fast layout (pretranspose)
+    state: torch.Tensor,  # [B, HV, V, K] - K-last layout (pretranspose)
Can we move it to flashinfer/gdn_decode.py (or make it a module)? We don't want to keep the directory flashinfer/cute_dsl in the future.
📌 Description
Implements high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- All tests are passing (unittest, etc.).

Summary by CodeRabbit
New Features
Benchmarks
Tests