Commit f3ea938

Add head_dim=64 for blackwell cutlass fmha implementation (#1850)
## 📌 Description

This PR adds head_dim=64 support to the Blackwell CUTLASS FMHA kernel and extends the unit tests and benchmarking scripts to cover the new head dimension. The benchmarking script was used to search for the optimal stageCountKV (hypothesized to be 3 for head_dim=64; it turned out to be 1).

## 🧪 Test Results (on B300, CUDA 13.0)

```
pytest tests/attention/test_blackwell_fmha.py
```

```
5616 passed, 720 skipped in 203.34s (0:03:23)
```

```
python benchmarks/bench_blackwell_attention.py
```

```
=== head_dim=128 ===
bench_fmha_blackwell (batch_size=128, qkv_len=512, num_heads=32, head_dim=128, causal=False), flops: 1024.563 TFLOPs/s
bench_fmha_blackwell (batch_size=64, qkv_len=1024, num_heads=32, head_dim=128, causal=False), flops: 1234.186 TFLOPs/s
bench_fmha_blackwell (batch_size=32, qkv_len=2048, num_heads=32, head_dim=128, causal=False), flops: 1386.312 TFLOPs/s
bench_fmha_blackwell (batch_size=16, qkv_len=4096, num_heads=32, head_dim=128, causal=False), flops: 1496.488 TFLOPs/s
bench_fmha_blackwell (batch_size=8, qkv_len=8192, num_heads=32, head_dim=128, causal=False), flops: 1540.769 TFLOPs/s
bench_fmha_blackwell (batch_size=4, qkv_len=16384, num_heads=32, head_dim=128, causal=False), flops: 1605.068 TFLOPs/s
bench_fmha_blackwell (batch_size=2, qkv_len=32768, num_heads=32, head_dim=128, causal=False), flops: 1648.648 TFLOPs/s
bench_fmha_blackwell (batch_size=1, qkv_len=65536, num_heads=32, head_dim=128, causal=False), flops: 1658.047 TFLOPs/s
bench_fmha_blackwell (batch_size=128, qkv_len=512, num_heads=32, head_dim=128, causal=True), flops: 440.781 TFLOPs/s
bench_fmha_blackwell (batch_size=64, qkv_len=1024, num_heads=32, head_dim=128, causal=True), flops: 638.431 TFLOPs/s
bench_fmha_blackwell (batch_size=32, qkv_len=2048, num_heads=32, head_dim=128, causal=True), flops: 963.078 TFLOPs/s
bench_fmha_blackwell (batch_size=16, qkv_len=4096, num_heads=32, head_dim=128, causal=True), flops: 1223.670 TFLOPs/s
bench_fmha_blackwell (batch_size=8, qkv_len=8192, num_heads=32, head_dim=128, causal=True), flops: 1379.715 TFLOPs/s
bench_fmha_blackwell (batch_size=4, qkv_len=16384, num_heads=32, head_dim=128, causal=True), flops: 1497.805 TFLOPs/s
bench_fmha_blackwell (batch_size=2, qkv_len=32768, num_heads=32, head_dim=128, causal=True), flops: 1584.493 TFLOPs/s
bench_fmha_blackwell (batch_size=1, qkv_len=65536, num_heads=32, head_dim=128, causal=True), flops: 1638.206 TFLOPs/s

=== head_dim=64 ===
bench_fmha_blackwell (batch_size=128, qkv_len=512, num_heads=32, head_dim=64, causal=False), flops: 449.641 TFLOPs/s
bench_fmha_blackwell (batch_size=64, qkv_len=1024, num_heads=32, head_dim=64, causal=False), flops: 520.870 TFLOPs/s
bench_fmha_blackwell (batch_size=32, qkv_len=2048, num_heads=32, head_dim=64, causal=False), flops: 596.860 TFLOPs/s
bench_fmha_blackwell (batch_size=16, qkv_len=4096, num_heads=32, head_dim=64, causal=False), flops: 654.122 TFLOPs/s
bench_fmha_blackwell (batch_size=8, qkv_len=8192, num_heads=32, head_dim=64, causal=False), flops: 673.011 TFLOPs/s
bench_fmha_blackwell (batch_size=4, qkv_len=16384, num_heads=32, head_dim=64, causal=False), flops: 791.186 TFLOPs/s
bench_fmha_blackwell (batch_size=2, qkv_len=32768, num_heads=32, head_dim=64, causal=False), flops: 872.266 TFLOPs/s
bench_fmha_blackwell (batch_size=1, qkv_len=65536, num_heads=32, head_dim=64, causal=False), flops: 870.826 TFLOPs/s
bench_fmha_blackwell (batch_size=128, qkv_len=512, num_heads=32, head_dim=64, causal=True), flops: 262.144 TFLOPs/s
bench_fmha_blackwell (batch_size=64, qkv_len=1024, num_heads=32, head_dim=64, causal=True), flops: 375.960 TFLOPs/s
bench_fmha_blackwell (batch_size=32, qkv_len=2048, num_heads=32, head_dim=64, causal=True), flops: 477.245 TFLOPs/s
bench_fmha_blackwell (batch_size=16, qkv_len=4096, num_heads=32, head_dim=64, causal=True), flops: 544.132 TFLOPs/s
bench_fmha_blackwell (batch_size=8, qkv_len=8192, num_heads=32, head_dim=64, causal=True), flops: 644.116 TFLOPs/s
bench_fmha_blackwell (batch_size=4, qkv_len=16384, num_heads=32, head_dim=64, causal=True), flops: 721.476 TFLOPs/s
bench_fmha_blackwell (batch_size=2, qkv_len=32768, num_heads=32, head_dim=64, causal=True), flops: 723.058 TFLOPs/s
bench_fmha_blackwell (batch_size=1, qkv_len=65536, num_heads=32, head_dim=64, causal=True), flops: 758.397 TFLOPs/s
```

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
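The TFLOPs/s figures above can be sanity-checked with the standard FMHA FLOP count: two GEMMs per head (Q·Kᵀ and P·V) of 2·qkv_len²·head_dim FLOPs each, with causal masking removing roughly half of the work. A minimal sketch of that accounting (the actual `flops` helper in `bench_blackwell_attention.py` may differ in detail):

```python
def attention_tflops_per_s(ms: float, batch_size: int, qkv_len: int,
                           num_heads: int, head_dim: int, causal: bool) -> float:
    # Two GEMMs per head (Q @ K^T and P @ V), 2 * qkv_len^2 * head_dim FLOPs each.
    total = 4.0 * batch_size * num_heads * qkv_len * qkv_len * head_dim
    if causal:
        total /= 2  # the causal mask removes roughly half of the score matrix
    return total / ms / 1e9  # FLOPs / (ms * 1e-3 s) / 1e12 == FLOPs / ms / 1e9

# e.g. batch_size=128, qkv_len=512, head_dim=64, causal=True at ~0.524 ms
# gives ~262 TFLOPs/s, matching the first causal head_dim=64 row above.
```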
1 parent ec4fc2c commit f3ea938

File tree

5 files changed (+33, -12 lines)


benchmarks/bench_blackwell_attention.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -81,6 +81,7 @@ def flops(ms):
 
 
 if __name__ == "__main__":
+    print("\n === head_dim=128 ===")
     bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16)
     bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16)
     bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16)
@@ -98,3 +99,22 @@ def flops(ms):
     bench_fmha_blackwell(4, 16384, 32, 128, True, torch.bfloat16)
     bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16)
     bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16)
+
+    print("\n === head_dim=64 ===")
+    bench_fmha_blackwell(128, 512, 32, 64, False, torch.bfloat16)
+    bench_fmha_blackwell(64, 1024, 32, 64, False, torch.bfloat16)
+    bench_fmha_blackwell(32, 2048, 32, 64, False, torch.bfloat16)
+    bench_fmha_blackwell(16, 4096, 32, 64, False, torch.bfloat16)
+    bench_fmha_blackwell(8, 8192, 32, 64, False, torch.bfloat16)
+    bench_fmha_blackwell(4, 16384, 32, 64, False, torch.bfloat16)
+    bench_fmha_blackwell(2, 32768, 32, 64, False, torch.bfloat16)
+    bench_fmha_blackwell(1, 65536, 32, 64, False, torch.bfloat16)
+
+    bench_fmha_blackwell(128, 512, 32, 64, True, torch.bfloat16)
+    bench_fmha_blackwell(64, 1024, 32, 64, True, torch.bfloat16)
+    bench_fmha_blackwell(32, 2048, 32, 64, True, torch.bfloat16)
+    bench_fmha_blackwell(16, 4096, 32, 64, True, torch.bfloat16)
+    bench_fmha_blackwell(8, 8192, 32, 64, True, torch.bfloat16)
+    bench_fmha_blackwell(4, 16384, 32, 64, True, torch.bfloat16)
+    bench_fmha_blackwell(2, 32768, 32, 64, True, torch.bfloat16)
+    bench_fmha_blackwell(1, 65536, 32, 64, True, torch.bfloat16)
```
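The positional arguments match the labels in the benchmark output: batch_size, qkv_len, num_heads, head_dim, causal, dtype. The head_dim=64 sweep added above could equivalently be written as a loop; a sketch, assuming the benchmark's own imports (the `bench_blackwell_attention` import path is hypothetical):

```python
import torch
from bench_blackwell_attention import bench_fmha_blackwell  # hypothetical import path

# Loop form of the head_dim=64 sweep, total tokens held at 65536 per config.
for causal in (False, True):
    for batch_size, qkv_len in [(128, 512), (64, 1024), (32, 2048), (16, 4096),
                                (8, 8192), (4, 16384), (2, 32768), (1, 65536)]:
        bench_fmha_blackwell(batch_size, qkv_len, 32, 64, causal, torch.bfloat16)
```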

csrc/fmha_cutlass_sm100.cu

Lines changed: 5 additions & 1 deletion

```diff
@@ -43,6 +43,10 @@ using tvm::ffi::Optional;
     constexpr int HEAD_DIM_QK = 128;                     \
     constexpr int HEAD_DIM_VO = 128;                     \
     return __VA_ARGS__();                                \
+  } else if (head_dim_qk == 64 && head_dim_vo == 64) {   \
+    constexpr int HEAD_DIM_QK = 64;                      \
+    constexpr int HEAD_DIM_VO = 64;                      \
+    return __VA_ARGS__();                                \
   }                                                      \
   return false;                                          \
 }()
@@ -98,7 +102,7 @@ void FMHACutlassSM100Run(ffi::Tensor workspace_buffer, ffi::Tensor q, ffi::Tenso
   using cutlass_type_in = cutlass_dtype_t<DTypeIn>;
   using cutlass_type_out = cutlass_dtype_t<DTypeOut>;
   using TILE_Q = _256;
-  using TILE_KV = _128;
+  using TILE_KV = std::conditional_t<HEAD_DIM_QK == 64, _64, _128>;
   using D_QK = cute::Int<HEAD_DIM_QK>;
   using D_VO = cute::Int<HEAD_DIM_VO>;
   using TileShapeQK = Shape<TILE_Q, TILE_KV, D_QK>;
```
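The dispatch macro turns the runtime head dims into compile-time constants, and the `std::conditional_t` then halves TILE_KV for the narrower head. A runtime Python mirror of that tile choice, for illustration only (the real logic is the compile-time C++ above; other head-dim branches, e.g. the 192/128 pair exercised by the tests, are elided here):

```python
def select_tile_shape_qk(head_dim_qk: int, head_dim_vo: int) -> tuple[int, int, int]:
    # Mirrors the compile-time choice: TILE_Q stays 256 while TILE_KV
    # drops from 128 to 64 when head_dim_qk == 64.
    if (head_dim_qk, head_dim_vo) not in {(128, 128), (64, 64)}:
        raise ValueError("head-dim pair not shown in this diff")
    TILE_Q = 256
    TILE_KV = 64 if head_dim_qk == 64 else 128
    return (TILE_Q, TILE_KV, head_dim_qk)  # TileShapeQK = Shape<TILE_Q, TILE_KV, D_QK>

assert select_tile_shape_qk(64, 64) == (256, 64, 64)
```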

include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -857,8 +857,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
 
   float2 scale_f32x2 = make_float2(scale, scale);
 
-  Tensor tTMrO =
-      make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
+  Tensor tTMrO = make_tensor<ElementPV>(
+      make_shape(shape(tTMEM_LOADcO), Int<get<1>(TileShapePV{}) / kCorrectionTileSize>{}));
 
   auto copy_in = [&](int i) {
     Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
```

include/flashinfer/attention/blackwell/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp

Lines changed: 3 additions & 3 deletions

```diff
@@ -785,7 +785,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
   // loop:
   //   TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG
   CUTLASS_PRAGMA_UNROLL
-  for (int i = 0; i < 128 / kCorrectionTileSize; i++) {
+  for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
     Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0;
     tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize);
     Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1;
@@ -867,8 +867,8 @@ struct Sm100FmhaGenMainloopWarpspecialized {
 
   float2 scale_f32x2 = make_float2(scale, scale);
 
-  Tensor tTMrO =
-      make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
+  Tensor tTMrO = make_tensor<ElementPV>(
+      make_shape(shape(tTMEM_LOADcO), Int<get<2>(TileShape{}) / kCorrectionTileSize>{}));
 
   auto copy_in = [&](int i) {
     Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
```
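Both mainloops previously hard-coded a 128-wide O tile when sizing the correction loop and the per-thread accumulator; reading the extent from the tile shape lets the same code serve head_dim=64. A Python mirror of the iteration count, with a hypothetical kCorrectionTileSize value:

```python
KCORRECTION_TILE_SIZE = 32  # hypothetical value, for illustration only

def correction_iters(tile_o_extent: int) -> int:
    # Before this change: 128 // KCORRECTION_TILE_SIZE regardless of head_dim.
    # After: the extent is read from TileShape / TileShapePV, so head_dim_vo=64
    # runs half as many TMEM_LOAD / FMUL2 / FFMA2 / STG iterations.
    return tile_o_extent // KCORRECTION_TILE_SIZE

assert correction_iters(128) == 4
assert correction_iters(64) == 2
```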

tests/attention/test_blackwell_fmha.py

Lines changed: 3 additions & 6 deletions

```diff
@@ -90,8 +90,7 @@ def attention_varlen_ref(
 @pytest.mark.parametrize("kv_len", [1, 17, 544, 977, 1999])
 @pytest.mark.parametrize("num_qo_heads", [32])
 @pytest.mark.parametrize("num_kv_heads", [8, 32])
-@pytest.mark.parametrize("head_dim_qk", [192, 128])
-@pytest.mark.parametrize("head_dim_vo", [128])
+@pytest.mark.parametrize("head_dim_qk,head_dim_vo", [(192, 128), (128, 128), (64, 64)])
 @pytest.mark.parametrize("sm_scale", [1.0, 1.0 / math.sqrt(192), 1.0 / math.sqrt(128)])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
@@ -169,8 +168,7 @@ def test_blackwell_cutlass_fmha(
 @pytest.mark.parametrize("indptr", VARLEN_INDPTR_PARAMS)
 @pytest.mark.parametrize("num_qo_heads", [32])
 @pytest.mark.parametrize("num_kv_heads", [8, 32])
-@pytest.mark.parametrize("head_dim_qk", [192, 128])
-@pytest.mark.parametrize("head_dim_vo", [128])
+@pytest.mark.parametrize("head_dim_qk,head_dim_vo", [(192, 128), (128, 128), (64, 64)])
 @pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
@@ -251,8 +249,7 @@ def test_blackwell_cutlass_varlen(
 @pytest.mark.parametrize("kv_indptr_list", [[0, 50, 50, 50, 50, 50, 50, 50]])
 @pytest.mark.parametrize("num_qo_heads", [32])
 @pytest.mark.parametrize("num_kv_heads", [8, 32])
-@pytest.mark.parametrize("head_dim_qk", [192, 128])
-@pytest.mark.parametrize("head_dim_vo", [128])
+@pytest.mark.parametrize("head_dim_qk,head_dim_vo", [(192, 128), (128, 128), (64, 64)])
 @pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)])
 @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
 def test_blackwell_cutlass_qo_kv_varlen(
```
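Fusing the two head-dim axes into one parametrize keeps the grid to supported pairs; two stacked decorators would cross-product into unsupported combinations such as (64, 128). A self-contained sketch of the pattern:

```python
import pytest

# Paired parametrization (as in the diff above) enumerates only supported
# (head_dim_qk, head_dim_vo) pairs, rather than the full cross product.
@pytest.mark.parametrize("head_dim_qk,head_dim_vo", [(192, 128), (128, 128), (64, 64)])
def test_head_dim_pair_is_supported(head_dim_qk, head_dim_vo):
    assert (head_dim_qk, head_dim_vo) in {(192, 128), (128, 128), (64, 64)}
```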
