Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions benchmarks/bench_blackwell_attention_cutedsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time

from flashinfer.cute_dsl.attention import BatchPrefillCuteDSLWrapper


def bench_fmha_blackwell(
    batch_size,
    qkv_len,
    num_heads,
    head_dim,
    causal,
    dtype,
):
    """Benchmark FlashInfer's cutlass-backend ragged batch-prefill attention.

    Builds a self-attention workload where every request in the batch has the
    same query/key length (``qkv_len``), runs it through
    ``BatchPrefillWithRaggedKVCacheWrapper`` and prints achieved TFLOPs/s and
    effective memory bandwidth (GB/s) based on the median measured latency.

    Args:
        batch_size: Number of requests in the batch.
        qkv_len: Sequence length shared by Q and KV for every request.
        num_heads: Number of attention heads (same for Q and KV; no GQA here).
        head_dim: Per-head dimension for both QK and VO.
        causal: Whether a causal mask is applied.
        dtype: Torch dtype for Q/K/V (e.g. ``torch.bfloat16``).
    """
    # Ragged layout: all sequences concatenated along dim 0 -> (tokens, H, D).
    q = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )
    k = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )
    v = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )

    # Equal-length segments: offsets are just multiples of qkv_len.
    qo_segment_offsets = (
        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    )
    kv_segment_offsets = (
        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    )
    # The workspace is raw scratch memory: allocate it as bytes (uint8) so the
    # buffer is 128 MB regardless of `dtype`, matching bench_fmha_cutedsl below
    # (the original `dtype=dtype` allocated 256 MB for 2-byte dtypes).
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
        kv_layout="NHD",
        backend="cutlass",
    )
    wrapper.plan(
        qo_segment_offsets,
        kv_segment_offsets,
        num_heads,
        num_heads,
        head_dim,
        head_dim_vo=head_dim,
        causal=causal,
        q_data_type=dtype,
        kv_data_type=dtype,
    )
    # Warm-up run (also gives us `o` for the I/O byte count below).
    o = wrapper.run(q, k, v)
    measurements = bench_gpu_time(
        lambda: wrapper.run(q, k, v),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
    ms = np.median(measurements)

    def flops(ms):
        # 4 * B * S^2 * H * D total FLOPs (QK^T + PV matmuls); a causal mask
        # halves the attended area, hence factor 2. `/ ms / 1e9` yields TFLOPs/s.
        if causal:
            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9
        else:
            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9

    def io(ms):
        # Bytes moved for Q, K, V and O; `/ ms / 1e6` yields GB/s.
        mem_size = (
            q.numel() * q.element_size()
            + k.numel() * k.element_size()
            + v.numel() * v.element_size()
            + o.numel() * o.element_size()
        )
        return mem_size / ms / 1e6

    print(
        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s, io: {io(ms):.3f} GB/s"
    )


def bench_fmha_cutedsl(
    batch_size,
    qkv_len,
    num_heads,
    head_dim,
    causal,
    dtype,
    sm_scale=None,
):
    """Benchmark the CuTe DSL batch-prefill attention wrapper.

    Runs a self-attention workload in which every request has the same
    query/key length, then prints achieved TFLOPs/s and effective memory
    bandwidth (GB/s) derived from the median measured latency.

    Args:
        batch_size: Number of requests in the batch.
        qkv_len: Sequence length shared by Q and KV for every request.
        num_heads: Number of attention heads (identical for Q and KV).
        head_dim: Per-head dimension for both QK and VO.
        causal: Whether a causal mask is applied.
        dtype: Torch dtype for Q/K/V (e.g. ``torch.bfloat16``).
        sm_scale: Softmax scale; defaults to ``1 / sqrt(head_dim)``.
    """
    if sm_scale is None:
        sm_scale = 1.0 / (head_dim**0.5)

    # Ragged layout: all sequences concatenated along dim 0 -> (tokens, H, D).
    total_tokens = batch_size * qkv_len
    q, k, v = (
        torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
        for _ in range(3)
    )

    # Equal-length segments: indptr entries are multiples of qkv_len.
    qo_indptr = (
        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    )
    kv_indptr = qo_indptr.clone()

    # 128 MB byte-typed scratch workspace for the kernel.
    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(
        workspace,
    )
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        num_heads,
        num_heads,
        head_dim,
        head_dim_vo=head_dim,
        causal=causal,
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )
    # Warm-up run; `o` is also needed for the I/O byte count below.
    o = wrapper.run(q, k, v)
    measurements = bench_gpu_time(
        lambda: wrapper.run(q, k, v),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
    ms = np.median(measurements)

    # 4 * B * S^2 * H * D total FLOPs (QK^T + PV); causal masking halves the
    # attended area. `/ ms / 1e9` converts FLOPs-per-ms to TFLOPs/s.
    flop_factor = 2 if causal else 4
    tflops = (
        batch_size * qkv_len * qkv_len * num_heads * head_dim * flop_factor / ms / 1e9
    )

    # Bytes moved for Q, K, V and O; `/ ms / 1e6` converts to GB/s.
    total_bytes = sum(t.numel() * t.element_size() for t in (q, k, v, o))
    gbps = total_bytes / ms / 1e6

    print(
        f"bench_fmha_cutedsl (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {tflops:.3f} TFLOPs/s, io: {gbps:.3f} GB/s"
    )


if __name__ == "__main__":
    # The CuTe DSL prefill kernel targets NVIDIA Blackwell (SM100) only.
    # Guard the default sweep so other machines exit cleanly with a clear
    # message instead of failing inside JIT compilation or kernel launch.
    if not torch.cuda.is_available():
        print("CUDA is not available; skipping Blackwell attention benchmarks.")
    else:
        major, minor = torch.cuda.get_device_capability()
        if major != 10:
            print(
                f"Detected compute capability {major}.{minor}; this benchmark "
                "requires SM100 (Blackwell). Skipping."
            )
        else:
            bench_fmha_cutedsl(128, 512, 32, 128, True, torch.bfloat16)
            bench_fmha_cutedsl(64, 1024, 32, 128, True, torch.bfloat16)
            bench_fmha_cutedsl(32, 2048, 32, 128, True, torch.bfloat16)
            bench_fmha_cutedsl(16, 4096, 32, 128, True, torch.bfloat16)
            bench_fmha_cutedsl(8, 8192, 32, 128, True, torch.bfloat16)
            bench_fmha_cutedsl(4, 16384, 32, 128, True, torch.bfloat16)
            bench_fmha_cutedsl(2, 32768, 32, 128, True, torch.bfloat16)
            bench_fmha_cutedsl(1, 65536, 32, 128, True, torch.bfloat16)
Comment on lines +153 to +161
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Skip the default sweep on unsupported GPUs.

The __main__ path unconditionally launches an SM100-only kernel. Running this script on another CUDA box will fail in JIT/launch instead of exiting cleanly with a clear message.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_attention_cutedsl.py` around lines 153 - 161, The
script currently unconditionally runs an SM100-only kernel in the __main__ block
(calls to bench_fmha_cutedsl), which will JIT/launch-fail on non-SM100 GPUs; add
a GPU capability check before running the default sweep: use
torch.cuda.is_available() and torch.cuda.get_device_capability() or
torch.cuda.get_device_properties(device).major/minor (or device name) to detect
whether the current GPU supports SM100, and if not, skip the default
bench_fmha_cutedsl(...) calls and exit or print a clear message; update the
__main__ section so the SM100-only sweep only runs when the capability check
passes.

43 changes: 43 additions & 0 deletions flashinfer/cute_dsl/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Modular attention kernels for CuTe DSL.

Kernels live at the top level of this package.
Building blocks (config, tmem_layout, roles, fusion, scheduler, wrappers) are
one level below in subdirectories.
"""

# Kernels
from .prefill import BlackwellFusedMultiHeadAttentionForward

# Building blocks
from .config import AttentionConfig, AttentionFusion, HeadMapping, TileBounds
from .tmem_layout import TmemLayout
from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE
from .pipeline_topology import (
    PipelineEdge,
    PipelineType,
    PipelineTopology,
    make_prefill_topology,
)
from .mainloop_spec import (
    MainloopSpec,
    make_prefill_mainloop_spec,
)
from .fusion.mask import MaskType
from .fusion.logits_transform import sigmoid_logits_transform
from .fusion.output_transform import dumb_output_transform
from .scheduler.persistent import (
    FmhaStaticTileScheduler,
    FmhaStaticTileSchedulerParams,
    create_fmha_static_tile_scheduler,
    create_fmha_static_tile_scheduler_params,
)

# Wrappers
from .wrappers.batch_prefill import (
    BatchPrefillCuteDSLWrapper,
    qkv_torch_2_cute,
    create_and_pad_tensor,
)

# Explicit public API: keeps `from flashinfer.cute_dsl.attention import *`
# stable and documents which re-exported names are intentional.
__all__ = [
    # Kernels
    "BlackwellFusedMultiHeadAttentionForward",
    # Building blocks
    "AttentionConfig",
    "AttentionFusion",
    "HeadMapping",
    "TileBounds",
    "TmemLayout",
    "WarpSchedule",
    "PREFILL_SCHEDULE",
    "PipelineEdge",
    "PipelineType",
    "PipelineTopology",
    "make_prefill_topology",
    "MainloopSpec",
    "make_prefill_mainloop_spec",
    "MaskType",
    "sigmoid_logits_transform",
    "dumb_output_transform",
    "FmhaStaticTileScheduler",
    "FmhaStaticTileSchedulerParams",
    "create_fmha_static_tile_scheduler",
    "create_fmha_static_tile_scheduler_params",
    # Wrappers
    "BatchPrefillCuteDSLWrapper",
    "qkv_torch_2_cute",
    "create_and_pad_tensor",
]
Loading
Loading