"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse

import numpy as np
import torch

from flashinfer.gdn_prefill import chunk_gated_delta_rule
from flashinfer.testing.utils import bench_gpu_time


def gdn_flops(
    total_seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    num_seqs: int,
) -> int:
    """
    Calculate FLOPs for Gated Delta Rule (GDN) attention.

    Delta Rule formula:
        state_t = alpha_t * state_{t-1} + beta_t * (k_t @ v_t^T)
        output_t = q_t @ state_t

    Matrix multiplications per token per head:
      1. k @ v^T (outer product): 2 * d^2 FLOPs
      2. q @ state:               2 * d^2 FLOPs

    Note: alpha/beta gating is element-wise scaling and is not counted
    in the FLOP total.
    """
    num_o_heads = max(num_q_heads, num_v_heads)

    # k @ v^T (outer product): 2 * d^2 per token per head
    outer_product_flops = 2 * total_seq_len * num_o_heads * head_size * head_size

    # q @ state: 2 * d^2 per token per head
    output_flops = 2 * total_seq_len * num_o_heads * head_size * head_size

    total_flops = outer_product_flops + output_flops
    return total_flops

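
# ---------------------------------------------------------------------------
# Illustrative reference (not used by the benchmark): a minimal single-head,
# single-sequence realization of the recurrence documented in gdn_flops(),
#     state_t = alpha_t * state_{t-1} + beta_t * (k_t @ v_t^T)
#     output_t = q_t @ state_t
# It is an O(T * d^2) sketch for reasoning about the FLOP count above and is
# not claimed to reproduce the chunked kernel's exact semantics or numerics.
# The helper name below is hypothetical and local to this file.
# ---------------------------------------------------------------------------
def _reference_gated_delta_rule_1h(
    q: torch.Tensor,      # [T, d]
    k: torch.Tensor,      # [T, d]
    v: torch.Tensor,      # [T, d]
    alpha: torch.Tensor,  # [T]
    beta: torch.Tensor,   # [T]
) -> torch.Tensor:
    seq_len, head_size = q.shape
    state = torch.zeros(head_size, head_size, dtype=torch.float32, device=q.device)
    out = torch.empty_like(v, dtype=torch.float32)
    for t in range(seq_len):
        # Decay the running state, then add the rank-1 update k_t v_t^T.
        state = alpha[t] * state + beta[t] * torch.outer(k[t].float(), v[t].float())
        # Read out with the query: output_t = q_t @ state_t.
        out[t] = q[t].float() @ state
    return out.to(v.dtype)
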

def gdn_bytes(
    total_seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    num_seqs: int,
    dtype: torch.dtype,
) -> int:
    """
    Calculate memory bytes for GDN attention.

    Includes:
      - Q, K, V tensors (input)
      - Output tensor
      - State tensor (float32)
      - Alpha, Beta tensors (optional, float32)
    """
    num_o_heads = max(num_q_heads, num_v_heads)
    num_sab_heads = num_o_heads
    elem_size = dtype.itemsize

    # Input tensors
    q_bytes = total_seq_len * num_q_heads * head_size * elem_size
    k_bytes = total_seq_len * num_k_heads * head_size * elem_size
    v_bytes = total_seq_len * num_v_heads * head_size * elem_size

    # Output tensor
    o_bytes = total_seq_len * num_o_heads * head_size * elem_size

    # State tensor (float32)
    state_bytes = num_seqs * num_sab_heads * head_size * head_size * 4

    # Alpha and Beta (float32)
    alpha_bytes = total_seq_len * num_sab_heads * 4
    beta_bytes = total_seq_len * num_sab_heads * 4

    total_bytes = (
        q_bytes + k_bytes + v_bytes + o_bytes + state_bytes + alpha_bytes + beta_bytes
    )
    return total_bytes
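

# Worked example for the two metric helpers above (illustrative arithmetic,
# shown as comments rather than executed code; the values follow directly
# from the formulas in gdn_flops/gdn_bytes):
#   >>> gdn_flops(1024, 16, 16, 32, 128, 1)
#   2147483648    # ~2.15 GFLOP
#   >>> gdn_bytes(1024, 16, 16, 32, 128, 1, torch.bfloat16)
#   27525120      # ~27.5 MB, i.e. roughly 78 FLOP/byte under this counting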


def bench_gdn_prefill(
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    dtype: torch.dtype,
    use_alpha: bool = True,
    use_beta: bool = True,
):
    """Benchmark the GDN prefill kernel for one configuration."""
    total_seq_len = batch_size * seq_len
    num_o_heads = max(num_q_heads, num_v_heads)
    num_sab_heads = num_o_heads

    # Create inputs
    q = torch.randn(total_seq_len, num_q_heads, head_size, dtype=dtype, device="cuda")
    k = torch.randn(total_seq_len, num_k_heads, head_size, dtype=dtype, device="cuda")
    # L2-normalize k for numerical stability
    k = torch.nn.functional.normalize(k, p=2.0, dim=-1)
    v = torch.randn(total_seq_len, num_v_heads, head_size, dtype=dtype, device="cuda")

    # Cumulative sequence lengths: [0, seq_len, 2 * seq_len, ..., batch_size * seq_len]
    cu_seqlens = torch.arange(
        0, batch_size * seq_len + 1, seq_len, dtype=torch.int64, device="cuda"
    )

    alpha = (
        torch.rand(total_seq_len, num_sab_heads, dtype=torch.float32, device="cuda")
        if use_alpha
        else None
    )
    beta = (
        torch.rand(total_seq_len, num_sab_heads, dtype=torch.float32, device="cuda")
        if use_beta
        else None
    )

    # Pre-allocate outputs
    output = torch.empty(
        total_seq_len, num_o_heads, head_size, dtype=dtype, device="cuda"
    )
    output_state = torch.empty(
        batch_size,
        num_sab_heads,
        head_size,
        head_size,
        dtype=torch.float32,
        device="cuda",
    )

    # Warmup
    chunk_gated_delta_rule(
        q, k, v, alpha, beta, None, None, True, cu_seqlens, False, output, output_state
    )
    torch.cuda.synchronize()

    # Benchmark
    times = bench_gpu_time(
        lambda: chunk_gated_delta_rule(
            q,
            k,
            v,
            alpha,
            beta,
            None,
            None,
            True,
            cu_seqlens,
            False,
            output,
            output_state,
        ),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
        enable_cupti=True,
    )

    median_ms = np.median(times)

    # Calculate metrics
    flops = gdn_flops(
        total_seq_len, num_q_heads, num_k_heads, num_v_heads, head_size, batch_size
    )
    bytes_accessed = gdn_bytes(
        total_seq_len,
        num_q_heads,
        num_k_heads,
        num_v_heads,
        head_size,
        batch_size,
        dtype,
    )

    # median_ms is in milliseconds, so FLOPs / (median_ms * 1e-3) / 1e12
    # simplifies to FLOPs / median_ms / 1e9 (and likewise for TB/s).
    tflops = flops / median_ms / 1e9
    tb_per_sec = bytes_accessed / median_ms / 1e9

    return {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "num_q_heads": num_q_heads,
        "num_k_heads": num_k_heads,
        "num_v_heads": num_v_heads,
        "head_size": head_size,
        "dtype": str(dtype).replace("torch.", ""),
        "median_ms": median_ms,
        "tflops": tflops,
        "tb_per_sec": tb_per_sec,
    }

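
# Illustrative programmatic usage, kept as comments so importing this module
# has no side effects; the configuration values are assumptions for the
# example, not a recommendation:
#
#     if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
#         stats = bench_gdn_prefill(
#             batch_size=4,
#             seq_len=512,
#             num_q_heads=16,
#             num_k_heads=16,
#             num_v_heads=32,
#             head_size=128,
#             dtype=torch.bfloat16,
#         )
#         print(stats["median_ms"], stats["tflops"], stats["tb_per_sec"])
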

def main():
    parser = argparse.ArgumentParser(description="Benchmark GDN Prefill Kernel")
    parser.add_argument("--batch-size", type=int, nargs="+", default=[1, 4, 16, 64])
    parser.add_argument("--seq-len", type=int, nargs="+", default=[128, 256, 512, 1024])
    parser.add_argument("--num-q-heads", type=int, default=16)
    parser.add_argument("--num-k-heads", type=int, default=16)
    parser.add_argument("--num-v-heads", type=int, default=32)
    parser.add_argument("--head-size", type=int, default=128)
    parser.add_argument(
        "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16"
    )
    parser.add_argument(
        "--preset",
        type=str,
        choices=["qwen3-next", "custom"],
        default="custom",
        help="Use preset config. qwen3-next: q=k=16, v=32, d=128",
    )
    args = parser.parse_args()

    # Apply preset configurations
    if args.preset == "qwen3-next":
        # Qwen3-Next-80B-A3B linear attention config (GVA)
        args.num_q_heads = 16
        args.num_k_heads = 16
        args.num_v_heads = 32
        args.head_size = 128

    # Check SM90 support
    device_capability = torch.cuda.get_device_capability()
    if device_capability[0] < 9:
        print(f"Current device capability: {device_capability}")
        print("GDN requires SM90 (Hopper) or later. Exiting...")
        return

    dtype = getattr(torch, args.dtype)

    print(
        f"GDN Prefill Benchmark (heads: q={args.num_q_heads}, k={args.num_k_heads}, "
        f"v={args.num_v_heads}, d={args.head_size}, dtype={args.dtype})"
    )
    print("-" * 100)
    print(f"{'batch':>6} {'seq_len':>8} {'time(ms)':>10} {'TFLOPS':>10} {'TB/s':>10}")
    print("-" * 100)

    for batch_size in args.batch_size:
        for seq_len in args.seq_len:
            result = bench_gdn_prefill(
                batch_size=batch_size,
                seq_len=seq_len,
                num_q_heads=args.num_q_heads,
                num_k_heads=args.num_k_heads,
                num_v_heads=args.num_v_heads,
                head_size=args.head_size,
                dtype=dtype,
            )
            print(
                f"{result['batch_size']:>6} {result['seq_len']:>8} "
                f"{result['median_ms']:>10.3f} {result['tflops']:>10.2f} "
                f"{result['tb_per_sec']:>10.2f}"
            )

    print("-" * 100)

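
# Example invocations (the script filename below is an assumption used for
# illustration; adjust it to wherever this file lives in the repo):
#   python bench_gdn_prefill.py --preset qwen3-next
#   python bench_gdn_prefill.py --batch-size 8 --seq-len 2048 --dtype float16
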
if __name__ == "__main__":
    main()