@@ -189,25 +189,23 @@ def _fwd_kernel_flash_decode_diverse_stage1(
189189 sum_exp = sum_exp * logic_scale + tl .sum (exp_logic , axis = 1 )
190190 max_logic = new_max_logic
191191
192- need_store = tl .where (block_n_size == 0 , 0 , 1 )
193- for _ in range (0 , need_store , 1 ):
194- off_mid_o = (
195- offs_batch [:, None , None ] * stride_mid_ob
196- + cur_q_head_range [None , :, None ] * stride_mid_oh
197- + seq_start_block * stride_mid_os
198- + offs_d [None , None , :]
199- )
200- off_mid_o_logexpsum = (
201- offs_batch [:, None ] * stride_mid_o_eb + cur_q_head_range [None , :] * stride_mid_o_eh + seq_start_block
202- )
203- tl .store (
204- Mid_O + off_mid_o ,
205- (acc / sum_exp [:, None ]).reshape (BLOCK_BATCH , BLOCK_HEAD , BLOCK_HEADDIM ),
206- )
207- tl .store (
208- Mid_O_LogExpSum + off_mid_o_logexpsum ,
209- (max_logic + tl .log (sum_exp )).reshape (BLOCK_BATCH , BLOCK_HEAD ),
210- )
192+ off_mid_o = (
193+ offs_batch [:, None , None ] * stride_mid_ob
194+ + cur_q_head_range [None , :, None ] * stride_mid_oh
195+ + seq_start_block * stride_mid_os
196+ + offs_d [None , None , :]
197+ )
198+ off_mid_o_logexpsum = (
199+ offs_batch [:, None ] * stride_mid_o_eb + cur_q_head_range [None , :] * stride_mid_o_eh + seq_start_block
200+ )
201+ tl .store (
202+ Mid_O + off_mid_o ,
203+ (acc / sum_exp [:, None ]).reshape (BLOCK_BATCH , BLOCK_HEAD , BLOCK_HEADDIM ),
204+ )
205+ tl .store (
206+ Mid_O_LogExpSum + off_mid_o_logexpsum ,
207+ (max_logic + tl .log (sum_exp )).reshape (BLOCK_BATCH , BLOCK_HEAD ),
208+ )
211209 return
212210
213211
0 commit comments