-
Notifications
You must be signed in to change notification settings - Fork 825
[CuTe DSL] Add modular FMHA prefill attention kernel #2805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pgera
wants to merge
14
commits into
flashinfer-ai:main
Choose a base branch
from
pgera:cutedsl-fmha-prefill
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+5,726
−0
Open
Changes from 6 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
9df6209
[cutedsl] Add modular attention package with FMHA and MLA kernels
pgera e541565
Remove monolithic kernels from original PR, separate benchmarks
pgera 8584f53
Fix causal mask boundary + PV accumulate bugs in modular attention ke…
pgera 77424f3
Fix attention sink precision bugs: wrapper dtype mismatch and M_D_upd…
pgera 9c37501
Fix sliding window mask, sink M_D_update scaling, and add comprehensi…
pgera 67ee085
Remove MLA/decode files from prefill PR
pgera 1cdf377
Replace patch/pipeline.py with upstream cutlass.pipeline
pgera d282b10
Address review feedback: add @cute.jit to get_trip_count, remove unus…
pgera 96eedbb
Simplify RESIDUAL_MASK check and remove redundant qo_head_idx guard (…
pgera 55c718b
Unify variant runtime data into single params mechanism with score_mo…
pgera 80ca7fe
Update copyright year to 2026 for new CuTe DSL attention files
pgera fc784dc
Update copyright year to 2026 for benchmark file
pgera 8241811
Standardize license headers to abbreviated SPDX BSD-3-Clause
pgera 9f0ba5e
Add validation and robustness improvements to CuTe DSL attention (AI-…
pgera File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,174 @@ | ||
| """ | ||
| Copyright (c) 2025 by FlashInfer team. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| import flashinfer | ||
| from flashinfer.testing.utils import bench_gpu_time | ||
|
|
||
| from flashinfer.cute_dsl.attention import BatchPrefillCuteDSLWrapper | ||
|
|
||
|
|
||
| def bench_fmha_blackwell( | ||
| batch_size, | ||
| qkv_len, | ||
| num_heads, | ||
| head_dim, | ||
| causal, | ||
| dtype, | ||
| ): | ||
| q = torch.randn( | ||
| batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" | ||
| ) | ||
| k = torch.randn( | ||
| batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" | ||
| ) | ||
| v = torch.randn( | ||
| batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" | ||
| ) | ||
|
|
||
| qo_segment_offsets = ( | ||
| torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len | ||
| ) | ||
| kv_segment_offsets = ( | ||
| torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len | ||
| ) | ||
| wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( | ||
| torch.empty(128 * 1024 * 1024, dtype=dtype, device="cuda"), | ||
| kv_layout="NHD", | ||
| backend="cutlass", | ||
| ) | ||
| wrapper.plan( | ||
| qo_segment_offsets, | ||
| kv_segment_offsets, | ||
| num_heads, | ||
| num_heads, | ||
| head_dim, | ||
| head_dim_vo=head_dim, | ||
| causal=causal, | ||
| q_data_type=dtype, | ||
| kv_data_type=dtype, | ||
| ) | ||
| o = wrapper.run(q, k, v) | ||
| measurements = bench_gpu_time( | ||
| lambda: wrapper.run(q, k, v), | ||
| dry_run_time_ms=100, | ||
| repeat_time_ms=1000, | ||
| ) | ||
| ms = np.median(measurements) | ||
|
|
||
| def flops(ms): | ||
| if causal: | ||
| return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9 | ||
| else: | ||
| return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9 | ||
|
|
||
| def io(ms): | ||
| mem_size = ( | ||
| q.numel() * q.element_size() | ||
| + k.numel() * k.element_size() | ||
| + v.numel() * v.element_size() | ||
| + o.numel() * o.element_size() | ||
| ) | ||
| return mem_size / ms / 1e6 | ||
|
|
||
| print( | ||
| f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s, io: {io(ms):.3f} GB/s" | ||
| ) | ||
|
|
||
|
|
||
| def bench_fmha_cutedsl( | ||
| batch_size, | ||
| qkv_len, | ||
| num_heads, | ||
| head_dim, | ||
| causal, | ||
| dtype, | ||
| sm_scale=None, | ||
| ): | ||
| if sm_scale is None: | ||
| sm_scale = 1.0 / (head_dim**0.5) | ||
|
|
||
| q = torch.randn( | ||
| batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" | ||
| ) | ||
| k = torch.randn( | ||
| batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" | ||
| ) | ||
| v = torch.randn( | ||
| batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" | ||
| ) | ||
|
|
||
| qo_indptr = ( | ||
| torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len | ||
| ) | ||
| kv_indptr = ( | ||
| torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len | ||
| ) | ||
|
|
||
| wrapper = BatchPrefillCuteDSLWrapper( | ||
| torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), | ||
| ) | ||
| wrapper.plan( | ||
| qo_indptr, | ||
| kv_indptr, | ||
| num_heads, | ||
| num_heads, | ||
| head_dim, | ||
| head_dim_vo=head_dim, | ||
| causal=causal, | ||
| sm_scale=sm_scale, | ||
| q_data_type=dtype, | ||
| kv_data_type=dtype, | ||
| ) | ||
| o = wrapper.run(q, k, v) | ||
| measurements = bench_gpu_time( | ||
| lambda: wrapper.run(q, k, v), | ||
| dry_run_time_ms=100, | ||
| repeat_time_ms=1000, | ||
| ) | ||
| ms = np.median(measurements) | ||
|
|
||
| def flops(ms): | ||
| if causal: | ||
| return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9 | ||
| else: | ||
| return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9 | ||
|
|
||
| def io(ms): | ||
| mem_size = ( | ||
| q.numel() * q.element_size() | ||
| + k.numel() * k.element_size() | ||
| + v.numel() * v.element_size() | ||
| + o.numel() * o.element_size() | ||
| ) | ||
| return mem_size / ms / 1e6 | ||
|
|
||
| print( | ||
| f"bench_fmha_cutedsl (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s, io: {io(ms):.3f} GB/s" | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| bench_fmha_cutedsl(128, 512, 32, 128, True, torch.bfloat16) | ||
| bench_fmha_cutedsl(64, 1024, 32, 128, True, torch.bfloat16) | ||
| bench_fmha_cutedsl(32, 2048, 32, 128, True, torch.bfloat16) | ||
| bench_fmha_cutedsl(16, 4096, 32, 128, True, torch.bfloat16) | ||
| bench_fmha_cutedsl(8, 8192, 32, 128, True, torch.bfloat16) | ||
| bench_fmha_cutedsl(4, 16384, 32, 128, True, torch.bfloat16) | ||
| bench_fmha_cutedsl(2, 32768, 32, 128, True, torch.bfloat16) | ||
| bench_fmha_cutedsl(1, 65536, 32, 128, True, torch.bfloat16) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| """Modular attention kernels for CuTe DSL. | ||
|
|
||
| Kernels live at the top level of this package. | ||
| Building blocks (config, tmem_layout, roles, fusion, scheduler, wrappers) are | ||
| one level below in subdirectories. | ||
| """ | ||
|
|
||
| # Kernels | ||
| from .prefill import BlackwellFusedMultiHeadAttentionForward | ||
|
|
||
| # Building blocks | ||
| from .config import AttentionConfig, AttentionFusion, HeadMapping, TileBounds | ||
| from .tmem_layout import TmemLayout | ||
| from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE | ||
| from .pipeline_topology import ( | ||
| PipelineEdge, | ||
| PipelineType, | ||
| PipelineTopology, | ||
| make_prefill_topology, | ||
| ) | ||
| from .mainloop_spec import ( | ||
| MainloopSpec, | ||
| make_prefill_mainloop_spec, | ||
| ) | ||
| from .fusion.mask import MaskType | ||
| from .fusion.logits_transform import sigmoid_logits_transform | ||
| from .fusion.output_transform import dumb_output_transform | ||
| from .scheduler.persistent import ( | ||
| FmhaStaticTileScheduler, | ||
| FmhaStaticTileSchedulerParams, | ||
| create_fmha_static_tile_scheduler, | ||
| create_fmha_static_tile_scheduler_params, | ||
| ) | ||
|
|
||
| # Wrappers | ||
| from .wrappers.batch_prefill import ( | ||
| BatchPrefillCuteDSLWrapper, | ||
| qkv_torch_2_cute, | ||
| create_and_pad_tensor, | ||
| ) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Skip the default sweep on unsupported GPUs.
The
__main__path unconditionally launches an SM100-only kernel. Running this script on another CUDA box will fail in JIT/launch instead of exiting cleanly with a clear message.🤖 Prompt for AI Agents