@@ -30,10 +30,10 @@ def get_fwd_dense_launch_config(
3030 if device .type == "cuda" :
3131 # If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
3232 if is_split_kv :
33- if pack_gqa and qheads_per_kvhead > 1 :
33+ if pack_gqa and qheads_per_kvhead > 16 :
3434 tile_m = triton .next_power_of_2 (qheads_per_kvhead )
3535 else :
36- tile_m = 1
36+ tile_m = 16
3737 else :
3838 # will be set based on architecture and tile_k
3939 tile_m = None
@@ -63,13 +63,13 @@ def get_fwd_dense_launch_config(
6363 elif arch // 10 == 9 :
6464 if not is_split_kv :
6565 if tile_k <= 64 :
66- return (256 , 128 , 4 , 1 , 1 )
67- elif tile_k <= 128 :
6866 return (128 , 128 , 4 , 1 , 1 )
69- elif tile_k <= 256 :
67+ elif tile_k <= 128 :
7068 return (128 , 64 , 4 , 1 , 1 )
69+ elif tile_k <= 256 :
70+ return (64 , 64 , 4 , 1 , 1 )
7171 else :
72- return (128 , 64 , 4 , 1 , 1 )
72+ return (64 , 64 , 4 , 1 , 1 )
7373 else :
7474 if tile_k <= 64 :
7575 return (tile_m , 256 , 4 , 1 , 1 )
@@ -141,10 +141,10 @@ def get_fwd_sparse_launch_config(
141141 if device .type == "cuda" :
142142 # If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
143143 if is_split_kv :
144- if pack_gqa and qheads_per_kvhead > 1 :
144+ if pack_gqa and qheads_per_kvhead > 16 :
145145 tile_m = triton .next_power_of_2 (qheads_per_kvhead )
146146 else :
147- tile_m = 1
147+ tile_m = 16
148148 else :
149149 # will be set based on architecture and tile_k
150150 tile_m = None
@@ -174,13 +174,13 @@ def get_fwd_sparse_launch_config(
174174 elif arch // 10 == 9 :
175175 if not is_split_kv :
176176 if tile_k <= 64 :
177- return (256 , 128 , 4 , 1 , 1 )
178- elif tile_k <= 128 :
179177 return (128 , 128 , 4 , 1 , 1 )
180- elif tile_k <= 256 :
178+ elif tile_k <= 128 :
181179 return (128 , 64 , 4 , 1 , 1 )
180+ elif tile_k <= 256 :
181+ return (64 , 64 , 4 , 1 , 1 )
182182 else :
183- return (128 , 64 , 4 , 1 , 1 )
183+ return (64 , 64 , 4 , 1 , 1 )
184184 else :
185185 if tile_k <= 64 :
186186 return (tile_m , 256 , 4 , 1 , 1 )
@@ -252,10 +252,10 @@ def get_fwd_gated_launch_config(
252252 if device .type == "cuda" :
253253 # If split KV, we set tile_m based on qheads_per_kvhead to ensure good occupancy
254254 if is_split_kv :
255- if pack_gqa and qheads_per_kvhead > 1 :
255+ if pack_gqa and qheads_per_kvhead > 16 :
256256 tile_m = triton .next_power_of_2 (qheads_per_kvhead )
257257 else :
258- tile_m = 1
258+ tile_m = 16
259259 else :
260260 # will be set based on architecture and tile_k
261261 tile_m = None
@@ -285,13 +285,13 @@ def get_fwd_gated_launch_config(
285285 elif arch // 10 == 9 :
286286 if not is_split_kv :
287287 if tile_k <= 64 :
288- return (256 , 128 , 4 , 1 , 1 )
289- elif tile_k <= 128 :
290288 return (128 , 128 , 4 , 1 , 1 )
291- elif tile_k <= 256 :
289+ elif tile_k <= 128 :
292290 return (128 , 64 , 4 , 1 , 1 )
291+ elif tile_k <= 256 :
292+ return (64 , 64 , 4 , 1 , 1 )
293293 else :
294- return (128 , 64 , 4 , 1 , 1 )
294+ return (64 , 64 , 4 , 1 , 1 )
295295 else :
296296 if tile_k <= 64 :
297297 return (tile_m , 256 , 4 , 1 , 1 )
0 commit comments