
Commit 792dcb1

Tune kernel compilation parameters for #1850 (#1878)
## 📌 Description

A follow-up to #1850 that tunes the pipeline stage / tile size values for a performance improvement (benchmarking results below). It also adjusts the test parametrization so that `sm_scale` is paired with realistic head-dim combinations.

## 🧪 Test results

### Unit testing results (tests/attention/test_blackwell_fmha.py)

```
================= 3096 passed, 240 skipped in 189.17s (0:03:09) =================
```

### Benchmarking results (benchmarks/bench_blackwell_attention.py)

(before)

```
=== head_dim=64 ===
bench_fmha_blackwell (batch_size=128, qkv_len=512, num_heads=32, head_dim=64, causal=False), flops: 448.139 TFLOPs/s
bench_fmha_blackwell (batch_size=64, qkv_len=1024, num_heads=32, head_dim=64, causal=False), flops: 520.066 TFLOPs/s
bench_fmha_blackwell (batch_size=32, qkv_len=2048, num_heads=32, head_dim=64, causal=False), flops: 595.861 TFLOPs/s
bench_fmha_blackwell (batch_size=16, qkv_len=4096, num_heads=32, head_dim=64, causal=False), flops: 653.053 TFLOPs/s
bench_fmha_blackwell (batch_size=8, qkv_len=8192, num_heads=32, head_dim=64, causal=False), flops: 671.899 TFLOPs/s
bench_fmha_blackwell (batch_size=4, qkv_len=16384, num_heads=32, head_dim=64, causal=False), flops: 788.719 TFLOPs/s
bench_fmha_blackwell (batch_size=2, qkv_len=32768, num_heads=32, head_dim=64, causal=False), flops: 869.262 TFLOPs/s
bench_fmha_blackwell (batch_size=1, qkv_len=65536, num_heads=32, head_dim=64, causal=False), flops: 868.034 TFLOPs/s
bench_fmha_blackwell (batch_size=128, qkv_len=512, num_heads=32, head_dim=64, causal=True), flops: 261.792 TFLOPs/s
bench_fmha_blackwell (batch_size=64, qkv_len=1024, num_heads=32, head_dim=64, causal=True), flops: 374.697 TFLOPs/s
bench_fmha_blackwell (batch_size=32, qkv_len=2048, num_heads=32, head_dim=64, causal=True), flops: 476.372 TFLOPs/s
bench_fmha_blackwell (batch_size=16, qkv_len=4096, num_heads=32, head_dim=64, causal=True), flops: 543.667 TFLOPs/s
bench_fmha_blackwell (batch_size=8, qkv_len=8192, num_heads=32, head_dim=64, causal=True), flops: 642.878 TFLOPs/s
bench_fmha_blackwell (batch_size=4, qkv_len=16384, num_heads=32, head_dim=64, causal=True), flops: 720.390 TFLOPs/s
bench_fmha_blackwell (batch_size=2, qkv_len=32768, num_heads=32, head_dim=64, causal=True), flops: 721.056 TFLOPs/s
bench_fmha_blackwell (batch_size=1, qkv_len=65536, num_heads=32, head_dim=64, causal=True), flops: 756.090 TFLOPs/s
```

(after)

```
bench_fmha_blackwell (batch_size=128, qkv_len=512, num_heads=32, head_dim=64, causal=False), flops: 695.429 TFLOPs/s
bench_fmha_blackwell (batch_size=64, qkv_len=1024, num_heads=32, head_dim=64, causal=False), flops: 876.748 TFLOPs/s
bench_fmha_blackwell (batch_size=32, qkv_len=2048, num_heads=32, head_dim=64, causal=False), flops: 985.989 TFLOPs/s
bench_fmha_blackwell (batch_size=16, qkv_len=4096, num_heads=32, head_dim=64, causal=False), flops: 1049.088 TFLOPs/s
bench_fmha_blackwell (batch_size=8, qkv_len=8192, num_heads=32, head_dim=64, causal=False), flops: 1093.423 TFLOPs/s
bench_fmha_blackwell (batch_size=4, qkv_len=16384, num_heads=32, head_dim=64, causal=False), flops: 1119.016 TFLOPs/s
bench_fmha_blackwell (batch_size=2, qkv_len=32768, num_heads=32, head_dim=64, causal=False), flops: 1138.080 TFLOPs/s
bench_fmha_blackwell (batch_size=1, qkv_len=65536, num_heads=32, head_dim=64, causal=False), flops: 1151.325 TFLOPs/s
bench_fmha_blackwell (batch_size=128, qkv_len=512, num_heads=32, head_dim=64, causal=True), flops: 273.278 TFLOPs/s
bench_fmha_blackwell (batch_size=64, qkv_len=1024, num_heads=32, head_dim=64, causal=True), flops: 416.845 TFLOPs/s
bench_fmha_blackwell (batch_size=32, qkv_len=2048, num_heads=32, head_dim=64, causal=True), flops: 616.595 TFLOPs/s
bench_fmha_blackwell (batch_size=16, qkv_len=4096, num_heads=32, head_dim=64, causal=True), flops: 810.543 TFLOPs/s
bench_fmha_blackwell (batch_size=8, qkv_len=8192, num_heads=32, head_dim=64, causal=True), flops: 940.429 TFLOPs/s
bench_fmha_blackwell (batch_size=4, qkv_len=16384, num_heads=32, head_dim=64, causal=True), flops: 1028.673 TFLOPs/s
bench_fmha_blackwell (batch_size=2, qkv_len=32768, num_heads=32, head_dim=64, causal=True), flops: 1083.968 TFLOPs/s
bench_fmha_blackwell (batch_size=1, qkv_len=65536, num_heads=32, head_dim=64, causal=True), flops: 1131.110 TFLOPs/s
```

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
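
As a quick sanity check on the numbers above, a small sketch (not part of the commit) that computes the speedup implied by two of the quoted head_dim=64, causal=False points:

```python
# Speedup implied by the before/after benchmarks quoted above
# (head_dim=64, causal=False; values in TFLOPs/s, copied verbatim).
before = {(128, 512): 448.139, (1, 65536): 868.034}
after = {(128, 512): 695.429, (1, 65536): 1151.325}

for (batch_size, qkv_len), tflops in before.items():
    print(f"batch_size={batch_size}, qkv_len={qkv_len}: "
          f"{after[(batch_size, qkv_len)] / tflops:.2f}x")
# batch_size=128, qkv_len=512: 1.55x
# batch_size=1, qkv_len=65536: 1.33x
```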
1 parent 74ee9d0 commit 792dcb1

File tree

3 files changed: +28 −9 lines changed


csrc/fmha_cutlass_sm100.cu

Lines changed: 1 addition & 1 deletion

```diff
@@ -102,7 +102,7 @@ void FMHACutlassSM100Run(ffi::Tensor workspace_buffer, ffi::Tensor q, ffi::Tenso
   using cutlass_type_in = cutlass_dtype_t<DTypeIn>;
   using cutlass_type_out = cutlass_dtype_t<DTypeOut>;
   using TILE_Q = _256;
-  using TILE_KV = std::conditional_t<HEAD_DIM_QK == 64, _64, _128>;
+  using TILE_KV = _128;
   using D_QK = cute::Int<HEAD_DIM_QK>;
   using D_VO = cute::Int<HEAD_DIM_VO>;
   using TileShapeQK = Shape<TILE_Q, TILE_KV, D_QK>;
```
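
For orientation, `TileShapeQK = Shape<TILE_Q, TILE_KV, D_QK>` packs (Q tile, KV tile, QK head dim). A minimal Python sketch (illustrative only; the real types are cute shapes in C++) of the tile shapes instantiated after this change, for the head dims the tests exercise:

```python
# Illustrative tuples mirroring the C++ Shape<TILE_Q, TILE_KV, D_QK> type.
TILE_Q, TILE_KV = 256, 128  # TILE_KV was 64 for head_dim 64 before this change

for head_dim_qk in (64, 128, 192):  # head dims exercised by the tests
    tile_shape_qk = (TILE_Q, TILE_KV, head_dim_qk)
    print(tile_shape_qk)  # e.g. (256, 128, 64); previously (256, 64, 64)
```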
include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp

Lines changed: 3 additions & 2 deletions

```diff
@@ -64,8 +64,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
   using Mask = Mask_;

   static constexpr int StageCountQ = 2;
-  static constexpr int StageCountKV =
-      get<2>(TileShapeQK{}) == 128 ? 2 : 1;  // sizeof(Element_) == 1 ? 2 : 2;
+  static constexpr int StageCountKV = (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64)
+                                          ? 2
+                                          : 1;  // sizeof(Element_) == 1 ? 2 : 2;

   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
   using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
```
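
Since `get<2>(TileShapeQK{})` is `D_QK` (the QK head dimension, per the tile shape defined in `csrc/fmha_cutlass_sm100.cu` above), the net effect is that two KV pipeline stages are now selected for head_dim 64 as well as 128. A Python restatement of the constexpr logic (illustrative, not the source):

```python
def stage_count_kv(head_dim_qk: int) -> int:
    """Mirror of the C++ constexpr StageCountKV selection above."""
    # Two KV pipeline stages for head_dim 64 or 128, one otherwise.
    return 2 if head_dim_qk in (64, 128) else 1

assert stage_count_kv(64) == 2   # was 1 before this change
assert stage_count_kv(128) == 2  # unchanged
assert stage_count_kv(192) == 1  # unchanged
```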

tests/attention/test_blackwell_fmha.py

Lines changed: 24 additions & 6 deletions

```diff
@@ -90,8 +90,14 @@ def attention_varlen_ref(
 @pytest.mark.parametrize("kv_len", [1, 17, 544, 977, 1999])
 @pytest.mark.parametrize("num_qo_heads", [32])
 @pytest.mark.parametrize("num_kv_heads", [8, 32])
-@pytest.mark.parametrize("head_dim_qk,head_dim_vo", [(192, 128), (128, 128), (64, 64)])
-@pytest.mark.parametrize("sm_scale", [1.0, 1.0 / math.sqrt(192), 1.0 / math.sqrt(128)])
+@pytest.mark.parametrize(
+    "head_dim_qk,head_dim_vo,sm_scale",
+    [
+        (192, 128, 1.0 / math.sqrt(192)),
+        (128, 128, 1.0 / math.sqrt(128)),
+        (64, 64, 1.0 / math.sqrt(64)),
+    ],
+)
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
 def test_blackwell_cutlass_fmha(
@@ -168,8 +174,14 @@ def test_blackwell_cutlass_fmha(
 @pytest.mark.parametrize("indptr", VARLEN_INDPTR_PARAMS)
 @pytest.mark.parametrize("num_qo_heads", [32])
 @pytest.mark.parametrize("num_kv_heads", [8, 32])
-@pytest.mark.parametrize("head_dim_qk,head_dim_vo", [(192, 128), (128, 128), (64, 64)])
-@pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)])
+@pytest.mark.parametrize(
+    "head_dim_qk,head_dim_vo,sm_scale",
+    [
+        (192, 128, 1.0 / math.sqrt(192)),
+        (128, 128, 1.0 / math.sqrt(128)),
+        (64, 64, 1.0 / math.sqrt(64)),
+    ],
+)
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
 def test_blackwell_cutlass_varlen(
@@ -249,8 +261,14 @@ def test_blackwell_cutlass_varlen(
 @pytest.mark.parametrize("kv_indptr_list", [[0, 50, 50, 50, 50, 50, 50, 50]])
 @pytest.mark.parametrize("num_qo_heads", [32])
 @pytest.mark.parametrize("num_kv_heads", [8, 32])
-@pytest.mark.parametrize("head_dim_qk,head_dim_vo", [(192, 128), (128, 128), (64, 64)])
-@pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)])
+@pytest.mark.parametrize(
+    "head_dim_qk,head_dim_vo,sm_scale",
+    [
+        (192, 128, 1.0 / math.sqrt(192)),
+        (128, 128, 1.0 / math.sqrt(128)),
+        (64, 64, 1.0 / math.sqrt(64)),
+    ],
+)
 @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
 def test_blackwell_cutlass_qo_kv_varlen(
     qo_indptr_list,
```
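
The three parametrizations above repeat the same coupled tuples. As a sketch of the pattern (a hypothetical refactor, not part of the commit), the tuples could be generated from the head-dim pairs so that `sm_scale = 1/sqrt(head_dim_qk)` stays in sync:

```python
import math

# Hypothetical helper mirroring the explicit lists in the tests above.
HEAD_DIM_PAIRS = [(192, 128), (128, 128), (64, 64)]
FMHA_HEAD_DIM_PARAMS = [
    (d_qk, d_vo, 1.0 / math.sqrt(d_qk)) for d_qk, d_vo in HEAD_DIM_PAIRS
]
# Usage: @pytest.mark.parametrize("head_dim_qk,head_dim_vo,sm_scale", FMHA_HEAD_DIM_PARAMS)
```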
