Ameyn/gdn decode cutedsl kernel #2498
Merged: yzh119 merged 12 commits into flashinfer-ai:main from ameynaik-hub:ameyn/gdn-decode-cutedsl-kernel on Feb 17, 2026.
Commits (12, diff below shows changes from 2 commits):

a5a2bac  Add Gated Delta Rule CuTe-DSL kernel for decode-phase inference (ameynaik-hub)
3ac695d  Fix ruff B905 linter error (ameynaik-hub)
e52e60a  Add bf16 h-state reference, fastmath/rsqrt/L1-bypass optimizations to… (ameynaik-hub)
8eac4d8  gated_delta_rule: use (32,4) smem layout for pred_sh/out_sh for coale… (ameynaik-hub)
ef0034f  gated_delta_rule: add LowBS-1 kernel for T=1, BS<=4 (ameynaik-hub)
4fffb16  gated_delta_rule: enable tvm-ffi (ameynaik-hub)
8fa0e9b  gdn_decode: add gdn_decode_klast_bf16_state backend and rename from i… (ameynaik-hub)
8a87cc8  Merge branch 'flashinfer-ai:main' into ameyn/gdn-decode-cutedsl-kernel (ameynaik-hub)
8d1e6b9  Refactor: Consolidate dtype parameters in reference implementation (ameynaik-hub)
696dca9  Update benchmarks/bench_gdn_decode.py (ameynaik-hub)
4c2c1b5  fix: Add parameter validation and improve cache key in gated_delta_rule (ameynaik-hub)
7c5b004  refactor: Move GDN CuTe DSL kernel to dedicated gdn_kernels module (ameynaik-hub)
benchmarks/bench_gdn_decode.py (new file, 167 lines):

```python
"""
Benchmark: Gated Delta Rule CuTe-DSL Kernel

Simple benchmark showing duration across batch sizes and sequence lengths (T=1,2,3,4).
"""

import math
import statistics

import torch


def get_l2_cache_size():
    """Get the L2 cache size in bytes for the current GPU."""
    return torch.cuda.get_device_properties(0).L2_cache_size


def benchmark(
    func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True
):
    """
    Benchmark a kernel with L2 flushing and return the median time in microseconds.

    Args:
        func: Function to benchmark
        num_iterations: Number of timed iterations
        n_warmup: Number of warmup iterations
        flush_l2: Whether to flush the L2 cache before each iteration
        use_dummy_matmul: Whether to use a dummy matmul for short-lived kernels
    """
    l2_size = get_l2_cache_size()
    cache_flush = torch.empty(l2_size, dtype=torch.uint8, device="cuda")

    # Dummy matmul for short-lived kernels (fills the GPU pipeline so CUDA events record properly)
    if use_dummy_matmul:
        A = torch.randn(4096, 4096, dtype=torch.float32, device="cuda")
        B = torch.randn(4096, 4096, dtype=torch.float32, device="cuda")
        _ = A @ B  # Warm up cuBLAS

    # Warmup
    for _ in range(n_warmup):
        if flush_l2:
            cache_flush.zero_()
        func()
    torch.cuda.synchronize()

    # Benchmark
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iterations)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iterations)]

    for i in range(num_iterations):
        if flush_l2:
            cache_flush.zero_()
        if use_dummy_matmul:
            _ = A @ B  # Dummy work to ensure events record properly for short kernels
        start_events[i].record()
        func()
        end_events[i].record()

    torch.cuda.synchronize()
    times_us = [
        s.elapsed_time(e) * 1000 for s, e in zip(start_events, end_events, strict=True)
    ]
    return statistics.median(times_us)


def create_inputs(B, T, H=16, HV=32, K=128, V=128):
    """Create test inputs."""
    return {
        "q": torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16),
        "k": torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16),
        "v": torch.randn(B, T, HV, V, device="cuda", dtype=torch.bfloat16),
        "a": torch.randn(B, T, HV, device="cuda", dtype=torch.bfloat16) * 0.1,
        "b": torch.randn(B, T, HV, device="cuda", dtype=torch.bfloat16),
        "A_log": torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1,
        "dt_bias": torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1,
        "state": torch.randn(B, HV, V, K, device="cuda", dtype=torch.bfloat16),
        "scale": 1.0 / math.sqrt(K),
    }


def main():
    from gated_delta_rule import gated_delta_rule

    print("=" * 70)
    print("Gated Delta Rule CuTe-DSL Kernel Benchmark")
    print("Config: H=16, HV=32, K=128, V=128, bfloat16")
    print("=" * 70)

    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    seqlens = [1, 2, 3, 4]
    num_iterations = 100

    # Results storage
    results = {T: {} for T in seqlens}

    # Benchmark each configuration
    for T in seqlens:
        print(f"\nCompiling and benchmarking T={T}...")
        for B in batch_sizes:
            inputs = create_inputs(B, T)
            state = inputs["state"].clone()

            # Warmup / compile
            _ = gated_delta_rule(
                A_log=inputs["A_log"],
                a=inputs["a"],
                dt_bias=inputs["dt_bias"],
                q=inputs["q"],
                k=inputs["k"],
                v=inputs["v"],
                b=inputs["b"],
                initial_state_source=state,
                scale=inputs["scale"],
            )

            def run_kernel():
                return gated_delta_rule(
                    A_log=inputs["A_log"],
                    a=inputs["a"],
                    dt_bias=inputs["dt_bias"],
                    q=inputs["q"],
                    k=inputs["k"],
                    v=inputs["v"],
                    b=inputs["b"],
                    initial_state_source=state,
                    scale=inputs["scale"],
                )

            time_us = benchmark(
                run_kernel,
                num_iterations=num_iterations,
                flush_l2=True,
                use_dummy_matmul=True,
            )
            results[T][B] = time_us
            print(f"  B={B:>3}: {time_us:>7.1f} us")

    # Summary table
    print("\n" + "=" * 70)
    print("SUMMARY: Duration (us) by Batch Size and Sequence Length")
    print("=" * 70)

    # Header
    header = f"{'B':>6} |"
    for T in seqlens:
        header += f" T={T} |"
    print(header)
    print("-" * 70)

    # Data rows
    for B in batch_sizes:
        row = f"{B:>6} |"
        for T in seqlens:
            row += f" {results[T][B]:>7.1f} |"
        print(row)

    print("-" * 70)

    # Averages
    print("\nAverage duration per T:")
    for T in seqlens:
        avg = sum(results[T].values()) / len(results[T])
        print(f"  T={T}: {avg:.1f} us")


if __name__ == "__main__":
    main()
```
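The `benchmark` helper reports the median of the per-iteration timings rather than the mean. A quick CPU-side illustration (with hypothetical timing numbers, not measurements from this kernel) shows why the median is the more robust summary when a few iterations are skewed by clock ramps or background work:

```python
import statistics

# Hypothetical per-iteration timings in microseconds: one slow outlier
# (e.g., a frequency ramp or a background kernel) skews the mean but
# leaves the median essentially untouched.
times_us = [10.1, 10.2, 10.0, 10.3, 250.0]

mean_us = statistics.mean(times_us)      # pulled far above the typical time
median_us = statistics.median(times_us)  # still reflects the typical time

print(f"mean={mean_us:.2f} us, median={median_us:.2f} us")
# -> mean=58.12 us, median=10.20 us
```

The same reasoning applies to the CUDA-event timings in the script: a single stalled iteration out of 100 barely moves the reported number.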
We already have benchmarking APIs https://docs.flashinfer.ai/generated/flashinfer.testing.bench_gpu_time.html#flashinfer.testing.bench_gpu_time, please refer to https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_gdn_decode.py#L1088-L1095 on how to use these APIs.
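The refactor being requested would replace the hand-rolled event loop with `flashinfer.testing.bench_gpu_time` and summarize the measurements it returns. Since the exact signature is documented at the link above rather than here, this sketch uses a hypothetical CPU stand-in with the same shape (callable in, list of per-iteration times out) to show only the calling pattern, not the real API surface:

```python
import statistics
import time


def bench_gpu_time_stub(fn, repeats=20):
    """Hypothetical CPU stand-in for flashinfer.testing.bench_gpu_time:
    calls `fn` `repeats` times and returns per-iteration durations in
    milliseconds. (The real API measures with CUDA events and handles
    warmup and L2 flushing internally.)"""
    times_ms = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - t0) * 1e3)
    return times_ms


def run_kernel():
    # Placeholder for the gated_delta_rule call in the benchmark script.
    sum(i * i for i in range(1000))


measured_ms = bench_gpu_time_stub(run_kernel)
median_us = statistics.median(measured_ms) * 1e3  # ms -> us, as the script reports
```

With the real API, `benchmark`, `get_l2_cache_size`, and the dummy-matmul machinery in the script would all collapse into a single call plus a `statistics.median` over the returned list.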
I'm just curious why that wasn't addressed and pointed out in #2370, but was here?
The benchmarking script in #2370 should use bench_gpu_time as well. I didn't notice that when reviewing that PR (it's my bad) and it was fixed in #2405.
Let's get it right in one shot in this PR.