@@ -36,19 +36,18 @@ def _fwd_kernel_flash_decode_stage1(
     stride_mid_o_eb,
     stride_mid_o_eh,
     stride_mid_o_es,
-    gqa_group_size,
     Q_HEAD_NUM: tl.constexpr,
     BLOCK_SEQ: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_ROPE_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    cur_batch = tl.program_id(0)
-    cur_kv_head = tl.program_id(1)
-    seq_start_block = tl.program_id(2)
+    seq_start_block = tl.program_id(0)
+    cur_q_head = tl.program_id(1)
+    cur_batch = tl.program_id(2)

     cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)
-    cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs
+    cur_q_head_range = cur_q_head * Q_HEAD_NUM + cur_q_head_offs

     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
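The launch geometry is remapped here. The old grid walked (batch, kv_head, seq_block) and derived each KV head's slice of query heads through gqa_group_size; the MLA latent cache has a single shared KV head, so the kv-head axis and the group-size arithmetic drop out, and each program now owns one tile of Q_HEAD_NUM consecutive query heads for one sequence block of one request. A minimal sketch of the new head tiling, with illustrative sizes (q_head_num = 32 and a 16-head tile are assumptions, not values from this diff):

Q_HEAD_NUM = 16  # heads handled per program (the head tile)

def head_range(cur_q_head):
    # mirrors cur_q_head * Q_HEAD_NUM + tl.arange(0, Q_HEAD_NUM) in the kernel
    return [cur_q_head * Q_HEAD_NUM + off for off in range(Q_HEAD_NUM)]

assert head_range(0) == list(range(0, 16))   # program 0 covers heads 0..15
assert head_range(1) == list(range(16, 32))  # program 1 covers heads 16..31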
@@ -59,7 +58,8 @@ def _fwd_kernel_flash_decode_stage1(

     off_q = cur_batch * stride_q_bs + cur_q_head_range[:, None] * stride_q_h + offs_d[None, :]
     off_rope_q = cur_batch * stride_q_rope_bs + cur_q_head_range[:, None] * stride_q_rope_h + offs_rope_d[None, :]
-
+    q = tl.load(Q_nope + off_q)
+    q_rope = tl.load(Q_rope + off_rope_q)
     block_n_size = (
         tl.where(
             cur_batch_end_index - cur_batch_start_index <= 0,
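With the gqa masks gone, both query tiles are loop-invariant, so they are loaded once, unmasked, before the KV loop instead of through the old masked loads further down. The 2-D offsets broadcast a head-range column against a feature-dim row; a small NumPy sketch of that shape arithmetic (the sizes and the contiguous stride are assumptions for illustration):

import numpy as np

Q_HEAD_NUM, BLOCK_DMODEL = 16, 512        # example tile sizes
cur_q_head_range = np.arange(Q_HEAD_NUM)  # query heads owned by this program
offs_d = np.arange(BLOCK_DMODEL)          # feature dimensions
stride_q_h = BLOCK_DMODEL                 # assume heads are laid out contiguously

off_q = cur_q_head_range[:, None] * stride_q_h + offs_d[None, :]
assert off_q.shape == (Q_HEAD_NUM, BLOCK_DMODEL)  # one [heads, dim] q tile per program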
@@ -70,27 +70,20 @@ def _fwd_kernel_flash_decode_stage1(
     )

     offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)
-
-    q = tl.load(Q_nope + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0)
-    q_rope = tl.load(
-        Q_rope + off_rope_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0
-    )
-
     sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)
     max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf")
     acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32)
-
     for start_n in range(0, block_n_size, 1):
         offs_n_new = start_n * BLOCK_N + offs_n
         kv_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
             mask=offs_n_new < cur_batch_end_index,
             other=0,
         )
-        off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None]
+        off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None]
         kv = tl.load(KV_nope + off_kv, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0)
         att_value = tl.dot(q, kv)
-        off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + cur_kv_head * stride_kv_rope_h + offs_rope_d[:, None]
+        off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None]
         rope_kv = tl.load(KV_rope + off_rope_kv, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0)
         att_value += tl.dot(q_rope, rope_kv)

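The head-stride terms disappear from the KV addressing for the same reason: with one latent KV head, a token's nope and rope vectors are located by token slot alone. The two tl.dot results accumulate into a single score matrix, which feeds the flash-decoding online-softmax recurrence in the lines just below this hunk. A NumPy reference of that recurrence, as a sketch of the standard update rather than the kernel's exact code:

import numpy as np

def online_softmax_step(acc, sum_exp, max_logic, att_value, v):
    # att_value: [Q_HEAD_NUM, BLOCK_N] scores; v: [BLOCK_N, BLOCK_DMODEL] values
    cur_max = np.maximum(max_logic, att_value.max(axis=1))
    exp_logic = np.exp(att_value - cur_max[:, None])
    scale = np.exp(max_logic - cur_max)        # rescale previously accumulated partials
    acc = acc * scale[:, None] + exp_logic @ v
    sum_exp = sum_exp * scale + exp_logic.sum(axis=1)
    return acc, sum_exp, cur_max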
@@ -120,12 +113,10 @@ def _fwd_kernel_flash_decode_stage1(
     tl.store(
         Mid_O + off_mid_o,
         acc / sum_exp[:, None],
-        mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size,
     )
     tl.store(
         Mid_O_LogExpSum + off_mid_o_logexpsum,
         max_logic + tl.log(sum_exp),
-        mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size,
     )
     return

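Since the head tile now covers exactly Q_HEAD_NUM valid heads, the stores need no masks: each program writes a dense partial output and its log-sum-exp for one sequence block. Stage 2 merges those partials across blocks with the usual log-sum-exp weighting; a NumPy sketch of that merge (illustrative only, not the repository's stage-2 kernel):

import numpy as np

def combine_partials(mid_o, mid_lse):
    # mid_o: [num_seq_blocks, Q_HEAD_NUM, BLOCK_DMODEL] partial outputs
    # mid_lse: [num_seq_blocks, Q_HEAD_NUM] per-block log-sum-exp values
    max_lse = mid_lse.max(axis=0)
    w = np.exp(mid_lse - max_lse)              # numerically stable block weights
    out = (w[:, :, None] * mid_o).sum(axis=0)
    return out / w.sum(axis=0)[:, None]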
@@ -147,6 +138,7 @@ def flash_decode_stage1(
 ):
     BLOCK_SEQ = block_seq
     BLOCK_N = 16
+    BLOCK_Q_HEAD = 16
     assert BLOCK_SEQ % BLOCK_N == 0
     # shape constraints
     q_nope_dim = q_nope.shape[-1]
@@ -158,9 +150,9 @@ def flash_decode_stage1(
     assert q_rope_dim in {16, 32, 64, 128, 256}

     sm_scale = softmax_scale  # softmax scale factor
-    batch, kv_head_num = B_req_idx.shape[0], kv_nope.shape[1]
-    grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))
-    gqa_group_size = q_nope.shape[1] // kv_nope.shape[1]
+    batch, q_head_num = B_req_idx.shape[0], q_nope.shape[1]
+    assert q_head_num % BLOCK_Q_HEAD == 0
+    grid = (triton.cdiv(max_len_in_batch, BLOCK_SEQ), q_head_num // BLOCK_Q_HEAD, batch)

     _fwd_kernel_flash_decode_stage1[grid](
         q_nope,
@@ -194,13 +186,12 @@ def flash_decode_stage1(
         mid_out_logsumexp.stride(0),
         mid_out_logsumexp.stride(1),
         mid_out_logsumexp.stride(2),
-        gqa_group_size,
-        Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)),
+        Q_HEAD_NUM=BLOCK_Q_HEAD,  # per-program head tile; must match the grid's head-tile axis
         BLOCK_SEQ=BLOCK_SEQ,
         BLOCK_DMODEL=q_nope_dim,
         BLOCK_ROPE_DMODEL=q_rope_dim,
         BLOCK_N=BLOCK_N,
-        num_warps=2,
+        num_warps=4,
         num_stages=2,
     )
     return
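On the host side, the new grid puts sequence blocks on the fastest-varying axis and replaces the kv-head axis with query-head tiles, and num_warps goes from 2 to 4 to serve the larger per-program tl.dot tiles. A worked example of the resulting launch shape (all sizes below are assumptions for illustration):

import triton

BLOCK_SEQ, BLOCK_Q_HEAD = 256, 16
batch, q_head_num, max_len_in_batch = 8, 32, 4096

grid = (triton.cdiv(max_len_in_batch, BLOCK_SEQ), q_head_num // BLOCK_Q_HEAD, batch)
assert grid == (16, 2, 8)  # seq blocks, then head tiles, then batch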