
Commit aff0d3b

guangyunh-nv, jiahanc, and yzh119 authored
feat: add GDN Attention (flashinfer-ai#2276)
## 📌 Description

This PR adds an implementation of the Gated Delta Rule (Gated Delta Net, GDN) for the Hopper architecture to better support Qwen-Next-like architectures.

## 🔍 Related Issues

flashinfer-ai#1690

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

Thanks @jiahanc for initiating the kernel integration and implementing the API.

## Summary by CodeRabbit

* **New Features**
  * SM90-optimized Gated Delta Rule (GDN) prefill: Python API (`chunk_gated_delta_rule`), host launcher, and FFI export; supports optional alpha/beta gating and returns the output and final state.
* **Benchmarks & Tests**
  * New GPU benchmark for GDN prefill reporting runtime, TFLOPS, and bandwidth.
  * Added reference implementations and comprehensive tests validating prefill, chunked prefill, and delta-rule behavior.

---------

Signed-off-by: jiahanc <[email protected]>
Co-authored-by: jiahanc <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
1 parent cda8f3f commit aff0d3b

31 files changed: +5887 -0 lines
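For readers skimming the diff, the recurrence that the benchmark's FLOP model is based on (see the `gdn_flops` docstring below) can be written as a naive per-token loop. The sketch below is illustrative only: it assumes a single head, the function name and shapes are made up for the example, and it is not the chunked SM90 kernel added by this PR.

```python
import torch


def gated_delta_rule_reference(q, k, v, alpha, beta):
    """Naive per-token reference for the recurrence in the benchmark's FLOP model:

        state_t = alpha_t * state_{t-1} + beta_t * (k_t @ v_t^T)
        o_t     = q_t @ state_t

    Single head for clarity: q/k/v have shape [T, D]; alpha/beta have shape [T].
    """
    T, D = q.shape
    state = torch.zeros(D, D, dtype=torch.float32, device=q.device)
    outputs = []
    for t in range(T):
        kv = torch.outer(k[t].float(), v[t].float())  # k_t v_t^T, shape [D, D]
        state = alpha[t] * state + beta[t] * kv       # gated state update
        outputs.append(q[t].float() @ state)          # o_t = q_t @ state_t
    return torch.stack(outputs), state
```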

benchmarks/bench_gdn_prefill.py

Lines changed: 282 additions & 0 deletions
```python
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse

import numpy as np
import torch

from flashinfer.gdn_prefill import chunk_gated_delta_rule
from flashinfer.testing.utils import bench_gpu_time


def gdn_flops(
    total_seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    num_seqs: int,
) -> int:
    """
    Calculate FLOPs for Gated Delta Rule (GDN) attention.

    Delta Rule formula:
        state_t = alpha_t * state_{t-1} + beta_t * (k_t @ v_t^T)
        output_t = q_t @ state_t

    Matrix multiplications per token per head:
        1. k @ v^T (outer product): 2 * d^2 FLOPs
        2. q @ state: 2 * d^2 FLOPs

    Note: alpha/beta gating are element-wise scalar multiplications,
    not counted in TFLOPS.
    """
    num_o_heads = max(num_q_heads, num_v_heads)

    # k @ v^T (outer product): 2 * d^2 per token per head
    outer_product_flops = 2 * total_seq_len * num_o_heads * head_size * head_size

    # q @ state: 2 * d^2 per token per head
    output_flops = 2 * total_seq_len * num_o_heads * head_size * head_size

    total_flops = outer_product_flops + output_flops
    return total_flops


def gdn_bytes(
    total_seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    num_seqs: int,
    dtype: torch.dtype,
) -> int:
    """
    Calculate memory bytes for GDN attention.

    Includes:
        - Q, K, V tensors (input)
        - Output tensor
        - State tensor (float32)
        - Alpha, Beta tensors (optional, float32)
    """
    num_o_heads = max(num_q_heads, num_v_heads)
    num_sab_heads = num_o_heads
    elem_size = dtype.itemsize

    # Input tensors
    q_bytes = total_seq_len * num_q_heads * head_size * elem_size
    k_bytes = total_seq_len * num_k_heads * head_size * elem_size
    v_bytes = total_seq_len * num_v_heads * head_size * elem_size

    # Output tensor
    o_bytes = total_seq_len * num_o_heads * head_size * elem_size

    # State tensor (float32)
    state_bytes = num_seqs * num_sab_heads * head_size * head_size * 4

    # Alpha and Beta (float32)
    alpha_bytes = total_seq_len * num_sab_heads * 4
    beta_bytes = total_seq_len * num_sab_heads * 4

    total_bytes = (
        q_bytes + k_bytes + v_bytes + o_bytes + state_bytes + alpha_bytes + beta_bytes
    )
    return total_bytes


def bench_gdn_prefill(
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    dtype: torch.dtype,
    use_alpha: bool = True,
    use_beta: bool = True,
):
    """Benchmark GDN prefill kernel."""
    total_seq_len = batch_size * seq_len
    num_o_heads = max(num_q_heads, num_v_heads)
    num_sab_heads = num_o_heads

    # Create inputs
    q = torch.randn(total_seq_len, num_q_heads, head_size, dtype=dtype, device="cuda")
    k = torch.randn(total_seq_len, num_k_heads, head_size, dtype=dtype, device="cuda")
    # L2-normalize k for numerical stability
    k = torch.nn.functional.normalize(k, p=2.0, dim=-1)
    v = torch.randn(total_seq_len, num_v_heads, head_size, dtype=dtype, device="cuda")

    cu_seqlens = torch.arange(
        0, batch_size * seq_len + 1, seq_len, dtype=torch.int64, device="cuda"
    )

    alpha = (
        torch.rand(total_seq_len, num_sab_heads, dtype=torch.float32, device="cuda")
        if use_alpha
        else None
    )
    beta = (
        torch.rand(total_seq_len, num_sab_heads, dtype=torch.float32, device="cuda")
        if use_beta
        else None
    )

    # Pre-allocate outputs
    output = torch.empty(
        total_seq_len, num_o_heads, head_size, dtype=dtype, device="cuda"
    )
    output_state = torch.empty(
        batch_size,
        num_sab_heads,
        head_size,
        head_size,
        dtype=torch.float32,
        device="cuda",
    )

    # Warmup
    chunk_gated_delta_rule(
        q, k, v, alpha, beta, None, None, True, cu_seqlens, False, output, output_state
    )
    torch.cuda.synchronize()

    # Benchmark
    times = bench_gpu_time(
        lambda: chunk_gated_delta_rule(
            q,
            k,
            v,
            alpha,
            beta,
            None,
            None,
            True,
            cu_seqlens,
            False,
            output,
            output_state,
        ),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
        enable_cupti=True,
    )

    median_ms = np.median(times)

    # Calculate metrics
    flops = gdn_flops(
        total_seq_len, num_q_heads, num_k_heads, num_v_heads, head_size, batch_size
    )
    bytes_accessed = gdn_bytes(
        total_seq_len,
        num_q_heads,
        num_k_heads,
        num_v_heads,
        head_size,
        batch_size,
        dtype,
    )

    tflops = flops / median_ms / 1e9
    tb_per_sec = bytes_accessed / median_ms / 1e9

    # Get device info for bandwidth calculation
    props = torch.cuda.get_device_properties(0)
    props.total_memory * 2 / 1e12  # Approximate peak bandwidth

    return {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "num_q_heads": num_q_heads,
        "num_k_heads": num_k_heads,
        "num_v_heads": num_v_heads,
        "head_size": head_size,
        "dtype": str(dtype).replace("torch.", ""),
        "median_ms": median_ms,
        "tflops": tflops,
        "tb_per_sec": tb_per_sec,
    }


def main():
    parser = argparse.ArgumentParser(description="Benchmark GDN Prefill Kernel")
    parser.add_argument("--batch-size", type=int, nargs="+", default=[1, 4, 16, 64])
    parser.add_argument("--seq-len", type=int, nargs="+", default=[128, 256, 512, 1024])
    parser.add_argument("--num-q-heads", type=int, default=16)
    parser.add_argument("--num-k-heads", type=int, default=16)
    parser.add_argument("--num-v-heads", type=int, default=32)
    parser.add_argument("--head-size", type=int, default=128)
    parser.add_argument(
        "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16"
    )
    parser.add_argument(
        "--preset",
        type=str,
        choices=["qwen3-next", "custom"],
        default="custom",
        help="Use preset config. qwen3-next: q=k=16, v=32, d=128",
    )
    args = parser.parse_args()

    # Apply preset configurations
    if args.preset == "qwen3-next":
        # Qwen3-Next-80B-A3B linear attention config (GVA)
        args.num_q_heads = 16
        args.num_k_heads = 16
        args.num_v_heads = 32
        args.head_size = 128

    # Check SM90 support
    device_capability = torch.cuda.get_device_capability()
    if device_capability[0] < 9:
        print(f"Current device capability: {device_capability}")
        print("GDN requires SM90 (Hopper) or later. Exiting...")
        return

    dtype = getattr(torch, args.dtype)

    print(
        f"GDN Prefill Benchmark (heads: q={args.num_q_heads}, k={args.num_k_heads}, v={args.num_v_heads}, d={args.head_size}, dtype={args.dtype})"
    )
    print("-" * 100)
    print(f"{'batch':>6} {'seq_len':>8} {'time(ms)':>10} {'TFLOPS':>10} {'TB/s':>10}")
    print("-" * 100)

    for batch_size in args.batch_size:
        for seq_len in args.seq_len:
            result = bench_gdn_prefill(
                batch_size=batch_size,
                seq_len=seq_len,
                num_q_heads=args.num_q_heads,
                num_k_heads=args.num_k_heads,
                num_v_heads=args.num_v_heads,
                head_size=args.head_size,
                dtype=dtype,
            )
            print(
                f"{result['batch_size']:>6} {result['seq_len']:>8} "
                f"{result['median_ms']:>10.3f} {result['tflops']:>10.2f} "
                f"{result['tb_per_sec']:>10.2f}"
            )

    print("-" * 100)


if __name__ == "__main__":
    main()
```
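A minimal sketch of driving the benchmark helper above for a single configuration, assuming an SM90+ GPU and that the script is importable from the current directory (the import path is an assumption, not part of this PR):

```python
import torch

from bench_gdn_prefill import bench_gdn_prefill  # import path assumed for illustration

# Single Qwen3-Next-shaped configuration, matching the preset in main()
result = bench_gdn_prefill(
    batch_size=4,
    seq_len=512,
    num_q_heads=16,
    num_k_heads=16,
    num_v_heads=32,
    head_size=128,
    dtype=torch.bfloat16,
)
print(result["median_ms"], result["tflops"], result["tb_per_sec"])
```

The same configuration can be run from the command line with `python benchmarks/bench_gdn_prefill.py --preset qwen3-next --dtype bfloat16`, using the flags defined in `main()` above.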

0 commit comments
