@@ -38,6 +38,37 @@ def _host_descriptor_pre_hook(nargs):
3838 "NUM_BUFFERS_QK" : 1 ,
3939 "NUM_MMA_GROUPS" : 2 ,
4040 "NUM_MMA_SLICES" : 2 ,
41+ "GROUP_SIZE_N" : 1 ,
42+ },
43+ num_stages = 0 ,
44+ num_warps = 4 ,
45+ pre_hook = _host_descriptor_pre_hook ,
46+ ),
47+ triton .Config (
48+ {
49+ "BLOCK_M" : 256 ,
50+ "BLOCK_N" : 128 ,
51+ "NUM_BUFFERS_Q" : 1 ,
52+ "NUM_BUFFERS_KV" : 3 ,
53+ "NUM_BUFFERS_QK" : 1 ,
54+ "NUM_MMA_GROUPS" : 2 ,
55+ "NUM_MMA_SLICES" : 2 ,
56+ "GROUP_SIZE_N" : 4 ,
57+ },
58+ num_stages = 0 ,
59+ num_warps = 4 ,
60+ pre_hook = _host_descriptor_pre_hook ,
61+ ),
62+ triton .Config (
63+ {
64+ "BLOCK_M" : 256 ,
65+ "BLOCK_N" : 128 ,
66+ "NUM_BUFFERS_Q" : 1 ,
67+ "NUM_BUFFERS_KV" : 6 ,
68+ "NUM_BUFFERS_QK" : 1 ,
69+ "NUM_MMA_GROUPS" : 2 ,
70+ "NUM_MMA_SLICES" : 2 ,
71+ "GROUP_SIZE_N" : 1 ,
4172 },
4273 num_stages = 0 ,
4374 num_warps = 4 ,
@@ -52,6 +83,7 @@ def _host_descriptor_pre_hook(nargs):
5283 "NUM_BUFFERS_QK" : 1 ,
5384 "NUM_MMA_GROUPS" : 2 ,
5485 "NUM_MMA_SLICES" : 2 ,
86+ "GROUP_SIZE_N" : 4 ,
5587 },
5688 num_stages = 0 ,
5789 num_warps = 4 ,
@@ -62,9 +94,13 @@ def _host_descriptor_pre_hook(nargs):
6294
def prune_configs_by_hdim(configs, named_args, **kwargs):
    """Select the autotune configs that match the launch's HEAD_DIM and STAGE.

    A config is kept only when BOTH of its meta-parameters equal the targets
    derived from the launch arguments:

    * ``NUM_BUFFERS_KV`` must be 6 when ``HEAD_DIM == 64`` and 3 otherwise.
    * ``GROUP_SIZE_N`` must be 4 when ``STAGE == 3`` and 1 otherwise.

    Args:
        configs: Iterable of config objects exposing a ``kwargs`` mapping
            (the ``triton.Config`` entries of the autotune list).
        named_args: Unused here; accepted because the pruning hook is called
            with it positionally.
        **kwargs: Launch keyword arguments. Must contain ``HEAD_DIM`` and
            ``STAGE``.

    Returns:
        A list with the matching configs, in their original order. The list
        is empty if no config carries the required combination.

    Raises:
        KeyError: If ``HEAD_DIM`` or ``STAGE`` is missing from ``kwargs``.
    """
    HEAD_DIM = kwargs["HEAD_DIM"]
    STAGE = kwargs["STAGE"]
    target_kv_buffers = 6 if HEAD_DIM == 64 else 3
    target_group_size_n = 4 if STAGE == 3 else 1
    # A config that lacks either key falls back to 0, which never equals a
    # target value, so such configs are always pruned.
    return [
        conf for conf in configs
        if conf.kwargs.get("NUM_BUFFERS_KV", 0) == target_kv_buffers
        and conf.kwargs.get("GROUP_SIZE_N", 0) == target_group_size_n
    ]
68104
69105
70106@triton .jit
@@ -140,9 +176,21 @@ def _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr):
140176
141177
142178@triton .jit
143- def _compute_offsets (tile_idx , n_tile_num , H , N_CTX , BLOCK_M , STAGE : tl .constexpr ):
144- start_m = tile_idx % n_tile_num
145- off_hz = tile_idx // n_tile_num
179+ def _compute_offsets (
180+ tile_idx ,
181+ H ,
182+ num_pid_n ,
183+ num_pid_in_group ,
184+ N_CTX ,
185+ BLOCK_M : tl .constexpr ,
186+ STAGE : tl .constexpr ,
187+ GROUP_SIZE_N : tl .constexpr ,
188+ ):
189+ group_id = tile_idx // num_pid_in_group
190+ first_pid_n = group_id * GROUP_SIZE_N
191+ group_size_n = min (num_pid_n - first_pid_n , GROUP_SIZE_N )
192+ start_m = (tile_idx % num_pid_in_group ) // group_size_n
193+ off_hz = first_pid_n + (tile_idx % group_size_n )
146194 off_z = off_hz // H
147195 off_h = off_hz % H
148196 offset_y = off_z * (N_CTX * H ) + off_h * N_CTX
@@ -282,6 +330,7 @@ def _attn_fwd_ws(sm_scale, M, #
282330 NUM_BUFFERS_QK : tl .constexpr , #
283331 NUM_MMA_GROUPS : tl .constexpr , #
284332 NUM_MMA_SLICES : tl .constexpr , #
333+ GROUP_SIZE_N : tl .constexpr , #
285334 ):
286335 tl .static_assert (NUM_MMA_GROUPS == 2 )
287336 tl .static_assert (NUM_BUFFERS_QK == 1 )
@@ -292,10 +341,12 @@ def _attn_fwd_ws(sm_scale, M, #
292341 # original grid
293342 # triton.cdiv(q.shape[2], META["BLOCK_M"]),
294343 # q.shape[0] * q.shape[1],
295- n_tile_num = tl .cdiv (N_CTX , BLOCK_M )
296344 prog_id = tl .program_id (0 )
297345 num_progs = tl .num_programs (0 )
298- total_tiles = n_tile_num * Z * H
346+ num_pid_m = tl .cdiv (N_CTX , BLOCK_M )
347+ num_pid_n = Z * H
348+ num_pid_in_group = num_pid_m * GROUP_SIZE_N
349+ total_tiles = num_pid_m * Z * H
299350
300351 tiles_per_sm = total_tiles // num_progs
301352 if prog_id < total_tiles % num_progs :
@@ -375,7 +426,15 @@ def _attn_fwd_ws(sm_scale, M, #
375426 for i in range (0 , tiles_per_sm ):
376427 # initialize offsets
377428 start_m , off_hz , lo , hi , qo_offset_y , kv_offset_y = _compute_offsets (
378- tile_idx , n_tile_num , H , N_CTX , BLOCK_M , STAGE )
429+ tile_idx ,
430+ H ,
431+ num_pid_n ,
432+ num_pid_in_group ,
433+ N_CTX ,
434+ BLOCK_M ,
435+ STAGE ,
436+ GROUP_SIZE_N ,
437+ )
379438 for _ in tl .range (lo , hi , BLOCK_N ):
380439 _ , phase = _get_bufidx_phase (accum_cnt , 1 )
381440 for cid in tl .static_range (0 , NUM_MMA_GROUPS ):
@@ -439,7 +498,15 @@ def _attn_fwd_ws(sm_scale, M, #
439498 for i in range (0 , tiles_per_sm ):
440499 # initialize offsets
441500 start_m , off_hz , lo , hi , qo_offset_y , kv_offset_y = _compute_offsets (
442- tile_idx , n_tile_num , H , N_CTX , BLOCK_M , STAGE )
501+ tile_idx ,
502+ H ,
503+ num_pid_n ,
504+ num_pid_in_group ,
505+ N_CTX ,
506+ BLOCK_M ,
507+ STAGE ,
508+ GROUP_SIZE_N ,
509+ )
443510 # initialize pointer to m and l
444511 m_i = tl .zeros ([BLOCK_M_SPLIT ], dtype = tl .float32 ) - float ("inf" )
445512 l_i = tl .zeros ([BLOCK_M_SPLIT ], dtype = tl .float32 ) + 1.0
@@ -517,7 +584,16 @@ def _attn_fwd_ws(sm_scale, M, #
517584
518585 for j in range (0 , tiles_per_sm ):
519586 # initialize offsets
520- _ , _ , lo , hi , _ , _ = _compute_offsets (tile_idx , n_tile_num , H , N_CTX , BLOCK_M , STAGE )
587+ _ , _ , lo , hi , _ , _ = _compute_offsets (
588+ tile_idx ,
589+ H ,
590+ num_pid_n ,
591+ num_pid_in_group ,
592+ N_CTX ,
593+ BLOCK_M ,
594+ STAGE ,
595+ GROUP_SIZE_N ,
596+ )
521597
522598 q_bufIdx , q_phase = _get_bufidx_phase (j , NUM_BUFFERS_Q )
523599 k_bufIdx , k_phase = _get_bufidx_phase (accum_cnt_kv , NUM_BUFFERS_KV )
@@ -685,8 +761,16 @@ def _attn_fwd_ws(sm_scale, M, #
685761 accum_cnt_kv = 0
686762 for i in range (0 , tiles_per_sm ):
687763 # initialize offsets
688- _ , _ , lo , hi , qo_offset_y , kv_offset_y = _compute_offsets (tile_idx , n_tile_num , H , N_CTX , BLOCK_M ,
689- STAGE )
764+ _ , _ , lo , hi , qo_offset_y , kv_offset_y = _compute_offsets (
765+ tile_idx ,
766+ H ,
767+ num_pid_n ,
768+ num_pid_in_group ,
769+ N_CTX ,
770+ BLOCK_M ,
771+ STAGE ,
772+ GROUP_SIZE_N ,
773+ )
690774
691775 # load q0
692776 q_bufIdx , q_phase = _get_bufidx_phase (i , NUM_BUFFERS_Q )
@@ -758,7 +842,16 @@ def _attn_fwd_ws(sm_scale, M, #
758842 # initialize offsets
759843 for i in range (0 , tiles_per_sm ):
760844 # initialize offsets
761- _ , _ , _ , _ , qo_offset_y , _ = _compute_offsets (tile_idx , n_tile_num , H , N_CTX , BLOCK_M , STAGE )
845+ _ , _ , _ , _ , qo_offset_y , _ = _compute_offsets (
846+ tile_idx ,
847+ H ,
848+ num_pid_n ,
849+ num_pid_in_group ,
850+ N_CTX ,
851+ BLOCK_M ,
852+ STAGE ,
853+ GROUP_SIZE_N ,
854+ )
762855 _ , phase = _get_bufidx_phase (i , 1 )
763856 for cid in tl .static_range (0 , NUM_MMA_GROUPS ):
764857 tlx .barrier_wait (o_fulls [cid ], phase )
0 commit comments