Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions flash_sparse_attn/ops/triton/launch_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def get_fwd_dense_launch_config(
if device.type == "cuda":
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
if is_split_kv:
if pack_gqa and qheads_per_kvhead > 1:
if pack_gqa and qheads_per_kvhead > 16:
tile_m = triton.next_power_of_2(qheads_per_kvhead)
else:
tile_m = 1
tile_m = 16
Comment on lines 32 to +36
else:
# will be set based on architecture and tile_k
tile_m = None
Expand Down Expand Up @@ -63,13 +63,13 @@ def get_fwd_dense_launch_config(
elif arch // 10 == 9:
if not is_split_kv:
if tile_k <= 64:
return (256, 128, 4, 1, 1)
elif tile_k <= 128:
return (128, 128, 4, 1, 1)
elif tile_k <= 256:
elif tile_k <= 128:
return (128, 64, 4, 1, 1)
elif tile_k <= 256:
return (64, 64, 4, 1, 1)
else:
return (128, 64, 4, 1, 1)
return (64, 64, 4, 1, 1)
else:
if tile_k <= 64:
return (tile_m, 256, 4, 1, 1)
Expand Down Expand Up @@ -141,10 +141,10 @@ def get_fwd_sparse_launch_config(
if device.type == "cuda":
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
if is_split_kv:
if pack_gqa and qheads_per_kvhead > 1:
if pack_gqa and qheads_per_kvhead > 16:
tile_m = triton.next_power_of_2(qheads_per_kvhead)
else:
tile_m = 1
tile_m = 16
else:
# will be set based on architecture and tile_k
tile_m = None
Expand Down Expand Up @@ -174,13 +174,13 @@ def get_fwd_sparse_launch_config(
elif arch // 10 == 9:
if not is_split_kv:
if tile_k <= 64:
return (256, 128, 4, 1, 1)
elif tile_k <= 128:
return (128, 128, 4, 1, 1)
elif tile_k <= 256:
elif tile_k <= 128:
return (128, 64, 4, 1, 1)
elif tile_k <= 256:
return (64, 64, 4, 1, 1)
else:
return (128, 64, 4, 1, 1)
return (64, 64, 4, 1, 1)
else:
if tile_k <= 64:
return (tile_m, 256, 4, 1, 1)
Expand Down Expand Up @@ -252,10 +252,10 @@ def get_fwd_gated_launch_config(
if device.type == "cuda":
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
if is_split_kv:
if pack_gqa and qheads_per_kvhead > 1:
if pack_gqa and qheads_per_kvhead > 16:
tile_m = triton.next_power_of_2(qheads_per_kvhead)
else:
tile_m = 1
tile_m = 16
Comment on lines 254 to +258
else:
# will be set based on architecture and tile_k
tile_m = None
Expand Down Expand Up @@ -285,13 +285,13 @@ def get_fwd_gated_launch_config(
elif arch // 10 == 9:
if not is_split_kv:
if tile_k <= 64:
return (256, 128, 4, 1, 1)
elif tile_k <= 128:
return (128, 128, 4, 1, 1)
elif tile_k <= 256:
elif tile_k <= 128:
return (128, 64, 4, 1, 1)
elif tile_k <= 256:
return (64, 64, 4, 1, 1)
else:
return (128, 64, 4, 1, 1)
return (64, 64, 4, 1, 1)
else:
if tile_k <= 64:
return (tile_m, 256, 4, 1, 1)
Expand Down
6 changes: 4 additions & 2 deletions tests/benchmark_backward.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional
import traceback

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
Expand Down Expand Up @@ -269,6 +270,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
cudnn_dense_tflops=cudnn_tflops,
)
except Exception as exc:
full_error = f"{exc}\n{traceback.format_exc()}"
return BenchmarkResult(
config=cfg,
triton_dense_ms=None,
Expand All @@ -281,7 +283,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
triton_gated_tflops=None,
fa_dense_tflops=None,
cudnn_dense_tflops=None,
error_message=str(exc),
error_message=full_error,
)
Comment on lines 272 to 287


Expand All @@ -290,7 +292,7 @@ def print_results(results: List[BenchmarkResult]) -> None:
if not ok:
print("No successful benchmark results.")
for r in results:
print(f"Failed: {r.config} -> {r.error_message}")
print(f"Failed: {r.config}\n{r.error_message}")
return

rows = []
Expand Down
6 changes: 4 additions & 2 deletions tests/benchmark_decode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional
import traceback

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
Expand Down Expand Up @@ -209,6 +210,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
cudnn_dense_tflops=cudnn_dense_tflops,
)
except Exception as exc:
full_error = f"{exc}\n{traceback.format_exc()}"
return BenchmarkResult(
config=cfg,
triton_dense_ms=None,
Expand All @@ -221,7 +223,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
triton_gated_tflops=None,
fa_dense_tflops=None,
cudnn_dense_tflops=None,
error_message=str(exc),
error_message=full_error,
Comment on lines 212 to +226
)


Expand All @@ -230,7 +232,7 @@ def print_results(results: List[BenchmarkResult]) -> None:
if not ok:
print("No successful benchmark results.")
for r in results:
print(f"Failed: {r.config} -> {r.error_message}")
print(f"Failed: {r.config}\n{r.error_message}")
return

rows = []
Expand Down
6 changes: 4 additions & 2 deletions tests/benchmark_forward.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional
import traceback

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
Expand Down Expand Up @@ -212,6 +213,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
cudnn_dense_tflops=cudnn_dense_tflops,
)
except Exception as exc:
full_error = f"{exc}\n{traceback.format_exc()}"
return BenchmarkResult(
config=cfg,
triton_dense_ms=None,
Expand All @@ -224,7 +226,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
triton_gated_tflops=None,
fa_dense_tflops=None,
cudnn_dense_tflops=None,
error_message=str(exc),
error_message=full_error,
Comment on lines 215 to +229
)


Expand All @@ -233,7 +235,7 @@ def print_results(results: List[BenchmarkResult]) -> None:
if not ok:
print("No successful benchmark results.")
for r in results:
print(f"Failed: {r.config} -> {r.error_message}")
print(f"Failed: {r.config}\n{r.error_message}")
return

rows = []
Expand Down
Loading