diff --git a/flash_sparse_attn/ops/triton/launch_template.py b/flash_sparse_attn/ops/triton/launch_template.py index efb1c9ed..8a6dd6b0 100644 --- a/flash_sparse_attn/ops/triton/launch_template.py +++ b/flash_sparse_attn/ops/triton/launch_template.py @@ -30,10 +30,10 @@ def get_fwd_dense_launch_config( if device.type == "cuda": # If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy if is_split_kv: - if pack_gqa and qheads_per_kvhead > 1: + if pack_gqa and qheads_per_kvhead > 16: tile_m = triton.next_power_of_2(qheads_per_kvhead) else: - tile_m = 1 + tile_m = 16 else: # will be set based on architecture and tile_k tile_m = None @@ -63,13 +63,13 @@ def get_fwd_dense_launch_config( elif arch // 10 == 9: if not is_split_kv: if tile_k <= 64: - return (256, 128, 4, 1, 1) - elif tile_k <= 128: return (128, 128, 4, 1, 1) - elif tile_k <= 256: + elif tile_k <= 128: return (128, 64, 4, 1, 1) + elif tile_k <= 256: + return (64, 64, 4, 1, 1) else: - return (128, 64, 4, 1, 1) + return (64, 64, 4, 1, 1) else: if tile_k <= 64: return (tile_m, 256, 4, 1, 1) @@ -141,10 +141,10 @@ def get_fwd_sparse_launch_config( if device.type == "cuda": # If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy if is_split_kv: - if pack_gqa and qheads_per_kvhead > 1: + if pack_gqa and qheads_per_kvhead > 16: tile_m = triton.next_power_of_2(qheads_per_kvhead) else: - tile_m = 1 + tile_m = 16 else: # will be set based on architecture and tile_k tile_m = None @@ -174,13 +174,13 @@ def get_fwd_sparse_launch_config( elif arch // 10 == 9: if not is_split_kv: if tile_k <= 64: - return (256, 128, 4, 1, 1) - elif tile_k <= 128: return (128, 128, 4, 1, 1) - elif tile_k <= 256: + elif tile_k <= 128: return (128, 64, 4, 1, 1) + elif tile_k <= 256: + return (64, 64, 4, 1, 1) else: - return (128, 64, 4, 1, 1) + return (64, 64, 4, 1, 1) else: if tile_k <= 64: return (tile_m, 256, 4, 1, 1) @@ -252,10 +252,10 @@ def get_fwd_gated_launch_config( if device.type == "cuda": # If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy if is_split_kv: - if pack_gqa and qheads_per_kvhead > 1: + if pack_gqa and qheads_per_kvhead > 16: tile_m = triton.next_power_of_2(qheads_per_kvhead) else: - tile_m = 1 + tile_m = 16 else: # will be set based on architecture and tile_k tile_m = None @@ -285,13 +285,13 @@ def get_fwd_gated_launch_config( elif arch // 10 == 9: if not is_split_kv: if tile_k <= 64: - return (256, 128, 4, 1, 1) - elif tile_k <= 128: return (128, 128, 4, 1, 1) - elif tile_k <= 256: + elif tile_k <= 128: return (128, 64, 4, 1, 1) + elif tile_k <= 256: + return (64, 64, 4, 1, 1) else: - return (128, 64, 4, 1, 1) + return (64, 64, 4, 1, 1) else: if tile_k <= 64: return (tile_m, 256, 4, 1, 1) diff --git a/tests/benchmark_backward.py b/tests/benchmark_backward.py index bfec9d3d..bad5b723 100644 --- a/tests/benchmark_backward.py +++ b/tests/benchmark_backward.py @@ -1,4 +1,5 @@ from typing import List, Optional +import traceback import torch from torch.nn.attention import sdpa_kernel, SDPBackend @@ -269,6 +270,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult: cudnn_dense_tflops=cudnn_tflops, ) except Exception as exc: + full_error = f"{exc}\n{traceback.format_exc()}" return BenchmarkResult( config=cfg, triton_dense_ms=None, @@ -281,7 +283,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult: triton_gated_tflops=None, fa_dense_tflops=None, cudnn_dense_tflops=None, - error_message=str(exc), + error_message=full_error, ) @@ -290,7 +292,7 @@ def print_results(results: List[BenchmarkResult]) -> None: if not ok: print("No successful benchmark results.") for r in results: - print(f"Failed: {r.config} -> {r.error_message}") + print(f"Failed: {r.config}\n{r.error_message}") return rows = [] diff --git a/tests/benchmark_decode.py b/tests/benchmark_decode.py index 7f4b0a50..8d95a8c0 100644 --- a/tests/benchmark_decode.py +++ b/tests/benchmark_decode.py @@ -1,4 +1,5 @@ from typing import List, Optional +import traceback import torch from torch.nn.attention import sdpa_kernel, SDPBackend @@ -209,6 +210,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult: cudnn_dense_tflops=cudnn_dense_tflops, ) except Exception as exc: + full_error = f"{exc}\n{traceback.format_exc()}" return BenchmarkResult( config=cfg, triton_dense_ms=None, @@ -221,7 +223,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult: triton_gated_tflops=None, fa_dense_tflops=None, cudnn_dense_tflops=None, - error_message=str(exc), + error_message=full_error, ) @@ -230,7 +232,7 @@ def print_results(results: List[BenchmarkResult]) -> None: if not ok: print("No successful benchmark results.") for r in results: - print(f"Failed: {r.config} -> {r.error_message}") + print(f"Failed: {r.config}\n{r.error_message}") return rows = [] diff --git a/tests/benchmark_forward.py b/tests/benchmark_forward.py index 0c10b897..aabaad0a 100644 --- a/tests/benchmark_forward.py +++ b/tests/benchmark_forward.py @@ -1,4 +1,5 @@ from typing import List, Optional +import traceback import torch from torch.nn.attention import sdpa_kernel, SDPBackend @@ -212,6 +213,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult: cudnn_dense_tflops=cudnn_dense_tflops, ) except Exception as exc: + full_error = f"{exc}\n{traceback.format_exc()}" return BenchmarkResult( config=cfg, triton_dense_ms=None, @@ -224,7 +226,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult: triton_gated_tflops=None, fa_dense_tflops=None, cudnn_dense_tflops=None, - error_message=str(exc), + error_message=full_error, ) @@ -233,7 +235,7 @@ def print_results(results: List[BenchmarkResult]) -> None: if not ok: print("No successful benchmark results.") for r in results: - print(f"Failed: {r.config} -> {r.error_message}") + print(f"Failed: {r.config}\n{r.error_message}") return rows = []