Skip to content

Commit 26fe0af

Browse files
authored
Merge pull request #251 from HKUSTDial/optim_triton_version
[BUG FIX] Improve error reporting and occupancy in benchmarks
2 parents f5befe0 + 0424884 commit 26fe0af

File tree

4 files changed

+30
-24
lines changed

4 files changed

+30
-24
lines changed

flash_sparse_attn/ops/triton/launch_template.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def get_fwd_dense_launch_config(
3030
if device.type == "cuda":
3131
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
3232
if is_split_kv:
33-
if pack_gqa and qheads_per_kvhead > 1:
33+
if pack_gqa and qheads_per_kvhead > 16:
3434
tile_m = triton.next_power_of_2(qheads_per_kvhead)
3535
else:
36-
tile_m = 1
36+
tile_m = 16
3737
else:
3838
# will be set based on architecture and tile_k
3939
tile_m = None
@@ -63,13 +63,13 @@ def get_fwd_dense_launch_config(
6363
elif arch // 10 == 9:
6464
if not is_split_kv:
6565
if tile_k <= 64:
66-
return (256, 128, 4, 1, 1)
67-
elif tile_k <= 128:
6866
return (128, 128, 4, 1, 1)
69-
elif tile_k <= 256:
67+
elif tile_k <= 128:
7068
return (128, 64, 4, 1, 1)
69+
elif tile_k <= 256:
70+
return (64, 64, 4, 1, 1)
7171
else:
72-
return (128, 64, 4, 1, 1)
72+
return (64, 64, 4, 1, 1)
7373
else:
7474
if tile_k <= 64:
7575
return (tile_m, 256, 4, 1, 1)
@@ -141,10 +141,10 @@ def get_fwd_sparse_launch_config(
141141
if device.type == "cuda":
142142
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
143143
if is_split_kv:
144-
if pack_gqa and qheads_per_kvhead > 1:
144+
if pack_gqa and qheads_per_kvhead > 16:
145145
tile_m = triton.next_power_of_2(qheads_per_kvhead)
146146
else:
147-
tile_m = 1
147+
tile_m = 16
148148
else:
149149
# will be set based on architecture and tile_k
150150
tile_m = None
@@ -174,13 +174,13 @@ def get_fwd_sparse_launch_config(
174174
elif arch // 10 == 9:
175175
if not is_split_kv:
176176
if tile_k <= 64:
177-
return (256, 128, 4, 1, 1)
178-
elif tile_k <= 128:
179177
return (128, 128, 4, 1, 1)
180-
elif tile_k <= 256:
178+
elif tile_k <= 128:
181179
return (128, 64, 4, 1, 1)
180+
elif tile_k <= 256:
181+
return (64, 64, 4, 1, 1)
182182
else:
183-
return (128, 64, 4, 1, 1)
183+
return (64, 64, 4, 1, 1)
184184
else:
185185
if tile_k <= 64:
186186
return (tile_m, 256, 4, 1, 1)
@@ -252,10 +252,10 @@ def get_fwd_gated_launch_config(
252252
if device.type == "cuda":
253253
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
254254
if is_split_kv:
255-
if pack_gqa and qheads_per_kvhead > 1:
255+
if pack_gqa and qheads_per_kvhead > 16:
256256
tile_m = triton.next_power_of_2(qheads_per_kvhead)
257257
else:
258-
tile_m = 1
258+
tile_m = 16
259259
else:
260260
# will be set based on architecture and tile_k
261261
tile_m = None
@@ -285,13 +285,13 @@ def get_fwd_gated_launch_config(
285285
elif arch // 10 == 9:
286286
if not is_split_kv:
287287
if tile_k <= 64:
288-
return (256, 128, 4, 1, 1)
289-
elif tile_k <= 128:
290288
return (128, 128, 4, 1, 1)
291-
elif tile_k <= 256:
289+
elif tile_k <= 128:
292290
return (128, 64, 4, 1, 1)
291+
elif tile_k <= 256:
292+
return (64, 64, 4, 1, 1)
293293
else:
294-
return (128, 64, 4, 1, 1)
294+
return (64, 64, 4, 1, 1)
295295
else:
296296
if tile_k <= 64:
297297
return (tile_m, 256, 4, 1, 1)

tests/benchmark_backward.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional
2+
import traceback
23

34
import torch
45
from torch.nn.attention import sdpa_kernel, SDPBackend
@@ -269,6 +270,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
269270
cudnn_dense_tflops=cudnn_tflops,
270271
)
271272
except Exception as exc:
273+
full_error = f"{exc}\n{traceback.format_exc()}"
272274
return BenchmarkResult(
273275
config=cfg,
274276
triton_dense_ms=None,
@@ -281,7 +283,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
281283
triton_gated_tflops=None,
282284
fa_dense_tflops=None,
283285
cudnn_dense_tflops=None,
284-
error_message=str(exc),
286+
error_message=full_error,
285287
)
286288

287289

@@ -290,7 +292,7 @@ def print_results(results: List[BenchmarkResult]) -> None:
290292
if not ok:
291293
print("No successful benchmark results.")
292294
for r in results:
293-
print(f"Failed: {r.config} -> {r.error_message}")
295+
print(f"Failed: {r.config}\n{r.error_message}")
294296
return
295297

296298
rows = []

tests/benchmark_decode.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional
2+
import traceback
23

34
import torch
45
from torch.nn.attention import sdpa_kernel, SDPBackend
@@ -209,6 +210,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
209210
cudnn_dense_tflops=cudnn_dense_tflops,
210211
)
211212
except Exception as exc:
213+
full_error = f"{exc}\n{traceback.format_exc()}"
212214
return BenchmarkResult(
213215
config=cfg,
214216
triton_dense_ms=None,
@@ -221,7 +223,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
221223
triton_gated_tflops=None,
222224
fa_dense_tflops=None,
223225
cudnn_dense_tflops=None,
224-
error_message=str(exc),
226+
error_message=full_error,
225227
)
226228

227229

@@ -230,7 +232,7 @@ def print_results(results: List[BenchmarkResult]) -> None:
230232
if not ok:
231233
print("No successful benchmark results.")
232234
for r in results:
233-
print(f"Failed: {r.config} -> {r.error_message}")
235+
print(f"Failed: {r.config}\n{r.error_message}")
234236
return
235237

236238
rows = []

tests/benchmark_forward.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Optional
2+
import traceback
23

34
import torch
45
from torch.nn.attention import sdpa_kernel, SDPBackend
@@ -212,6 +213,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
212213
cudnn_dense_tflops=cudnn_dense_tflops,
213214
)
214215
except Exception as exc:
216+
full_error = f"{exc}\n{traceback.format_exc()}"
215217
return BenchmarkResult(
216218
config=cfg,
217219
triton_dense_ms=None,
@@ -224,7 +226,7 @@ def run_benchmark(cfg: BenchmarkConfig) -> BenchmarkResult:
224226
triton_gated_tflops=None,
225227
fa_dense_tflops=None,
226228
cudnn_dense_tflops=None,
227-
error_message=str(exc),
229+
error_message=full_error,
228230
)
229231

230232

@@ -233,7 +235,7 @@ def print_results(results: List[BenchmarkResult]) -> None:
233235
if not ok:
234236
print("No successful benchmark results.")
235237
for r in results:
236-
print(f"Failed: {r.config} -> {r.error_message}")
238+
print(f"Failed: {r.config}\n{r.error_message}")
237239
return
238240

239241
rows = []

0 commit comments

Comments
 (0)