Skip to content

Commit 0424884

Browse files
committed
Update launch configuration logic for forward and sparse kernels to improve occupancy
1 parent 1e7d6d6 commit 0424884

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

flash_sparse_attn/ops/triton/launch_template.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -30,10 +30,10 @@ def get_fwd_dense_launch_config(
3030
if device.type == "cuda":
3131
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
3232
if is_split_kv:
33-
if pack_gqa and qheads_per_kvhead > 1:
33+
if pack_gqa and qheads_per_kvhead > 16:
3434
tile_m = triton.next_power_of_2(qheads_per_kvhead)
3535
else:
36-
tile_m = 1
36+
tile_m = 16
3737
else:
3838
# will be set based on architecture and tile_k
3939
tile_m = None
@@ -63,13 +63,13 @@ def get_fwd_dense_launch_config(
6363
elif arch // 10 == 9:
6464
if not is_split_kv:
6565
if tile_k <= 64:
66-
return (256, 128, 4, 1, 1)
67-
elif tile_k <= 128:
6866
return (128, 128, 4, 1, 1)
69-
elif tile_k <= 256:
67+
elif tile_k <= 128:
7068
return (128, 64, 4, 1, 1)
69+
elif tile_k <= 256:
70+
return (64, 64, 4, 1, 1)
7171
else:
72-
return (128, 64, 4, 1, 1)
72+
return (64, 64, 4, 1, 1)
7373
else:
7474
if tile_k <= 64:
7575
return (tile_m, 256, 4, 1, 1)
@@ -141,10 +141,10 @@ def get_fwd_sparse_launch_config(
141141
if device.type == "cuda":
142142
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
143143
if is_split_kv:
144-
if pack_gqa and qheads_per_kvhead > 1:
144+
if pack_gqa and qheads_per_kvhead > 16:
145145
tile_m = triton.next_power_of_2(qheads_per_kvhead)
146146
else:
147-
tile_m = 1
147+
tile_m = 16
148148
else:
149149
# will be set based on architecture and tile_k
150150
tile_m = None
@@ -174,13 +174,13 @@ def get_fwd_sparse_launch_config(
174174
elif arch // 10 == 9:
175175
if not is_split_kv:
176176
if tile_k <= 64:
177-
return (256, 128, 4, 1, 1)
178-
elif tile_k <= 128:
179177
return (128, 128, 4, 1, 1)
180-
elif tile_k <= 256:
178+
elif tile_k <= 128:
181179
return (128, 64, 4, 1, 1)
180+
elif tile_k <= 256:
181+
return (64, 64, 4, 1, 1)
182182
else:
183-
return (128, 64, 4, 1, 1)
183+
return (64, 64, 4, 1, 1)
184184
else:
185185
if tile_k <= 64:
186186
return (tile_m, 256, 4, 1, 1)
@@ -252,10 +252,10 @@ def get_fwd_gated_launch_config(
252252
if device.type == "cuda":
253253
# If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
254254
if is_split_kv:
255-
if pack_gqa and qheads_per_kvhead > 1:
255+
if pack_gqa and qheads_per_kvhead > 16:
256256
tile_m = triton.next_power_of_2(qheads_per_kvhead)
257257
else:
258-
tile_m = 1
258+
tile_m = 16
259259
else:
260260
# will be set based on architecture and tile_k
261261
tile_m = None
@@ -285,13 +285,13 @@ def get_fwd_gated_launch_config(
285285
elif arch // 10 == 9:
286286
if not is_split_kv:
287287
if tile_k <= 64:
288-
return (256, 128, 4, 1, 1)
289-
elif tile_k <= 128:
290288
return (128, 128, 4, 1, 1)
291-
elif tile_k <= 256:
289+
elif tile_k <= 128:
292290
return (128, 64, 4, 1, 1)
291+
elif tile_k <= 256:
292+
return (64, 64, 4, 1, 1)
293293
else:
294-
return (128, 64, 4, 1, 1)
294+
return (64, 64, 4, 1, 1)
295295
else:
296296
if tile_k <= 64:
297297
return (tile_m, 256, 4, 1, 1)

0 commit comments

Comments
 (0)