@@ -116,7 +116,6 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     d_off = tl.arange(0, HEAD_DIM)
 
     cur_batch = tl.load(chunk2batch_tensor + chunk_idx)
-    cur_req = tl.load(b_req_idx_tensor + cur_batch)
     cur_seq_len = tl.load(b_seq_len_tensor + cur_batch)
     cur_start = tl.load(chunk2start_tensor + chunk_idx)
     cur_end = tl.minimum(cur_start + CHUNK_SIZE, cur_seq_len)
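Each stage-1 program handles one (chunk, kv_head) pair, so the kernel first resolves which request and token window its chunk covers. For orientation, here is a minimal host-side sketch of how the `chunk2batch` / `chunk2start` bookkeeping could be built from `b_seq_len`; `build_chunk_maps` is a hypothetical helper, not part of this commit:

```python
import torch

def build_chunk_maps(b_seq_len: torch.Tensor, chunk_size: int):
    # Hypothetical helper: split each request's KV range into fixed-size
    # chunks; chunk2batch[i] names the owning request, chunk2start[i] its
    # token offset. Stage 1 then reduces each chunk independently.
    chunk2batch, chunk2start = [], []
    for batch_idx, seq_len in enumerate(b_seq_len.tolist()):
        for start in range(0, seq_len, chunk_size):
            chunk2batch.append(batch_idx)
            chunk2start.append(start)
    return torch.tensor(chunk2batch), torch.tensor(chunk2start)
```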
@@ -133,9 +132,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     max_exp = tl.zeros([Q_GROUP_SIZE], dtype=tl.float32) - float("inf")
     accum = tl.zeros([Q_GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
 
-    for block_idx in tl.range(0, cur_block_num, 1):
+    for block_idx in range(0, cur_block_num, 1):
         block_range = cur_start + block_idx * BLOCK_N + tl.arange(0, BLOCK_N)  # shape [BLOCK_N]
         block_mask = block_range < cur_end  # shape [BLOCK_N]
+        cur_req = tl.load(b_req_idx_tensor + cur_batch)
         cur_kv_loc = tl.load(
             req_to_token_idx_tensor
             + cur_req * req_to_token_idx_stride_bs
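Two scheduling tweaks in this hunk: the loop drops `tl.range` for plain `range`, and the loop-invariant `cur_req` load is re-issued each iteration, presumably to steer how the compiler schedules the loads; the indices the loop produces are unchanged. A pure-Python model of that index arithmetic, assuming `cur_block_num = ceil((cur_end - cur_start) / BLOCK_N)` (its actual definition sits outside this hunk):

```python
import math

def walk_chunk(cur_start: int, cur_end: int, block_n: int):
    # Model of the inner loop: tile [cur_start, cur_end) in BLOCK_N steps,
    # masking only the ragged tail at the sequence end.
    cur_block_num = math.ceil((cur_end - cur_start) / block_n)
    for block_idx in range(cur_block_num):
        block_range = [cur_start + block_idx * block_n + i for i in range(block_n)]
        block_mask = [t < cur_end for t in block_range]
        yield block_range, block_mask

for rng, mask in walk_chunk(cur_start=0, cur_end=50, block_n=32):
    print(rng[0], rng[-1], sum(mask))  # 0 31 32, then 32 63 18
```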
@@ -156,10 +156,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
 
         exp_logic = tl.exp(att - new_max[:, None])
         log_scale = tl.exp(max_exp - new_max)
-        accum *= log_scale[:, None]
 
         v_off = cur_kv_loc[:, None] * v_stride_token + kv_head_idx * v_stride_h + d_off[None, :] * v_stride_d
         v = tl.load(v_tensor + v_off, mask=block_mask[:, None], other=0.0)
+        accum *= log_scale[:, None]
         accum += tl.dot(exp_logic.to(v.dtype), v)
 
         sum_exp = sum_exp * log_scale + tl.sum(exp_logic, axis=1)
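Moving the `accum` rescale below the `v` load lets the load be issued earlier; it is safe because neither `v_off` nor `v` reads `accum`. The underlying online-softmax update only requires that the rescale happen before the `accum +=`, as this minimal NumPy sketch of one loop iteration shows (illustrative, not the kernel's code):

```python
import numpy as np

def online_softmax_step(accum, sum_exp, max_exp, att, v):
    # att: [Q, BLOCK_N] scores for this block; v: [BLOCK_N, D] values.
    new_max = np.maximum(max_exp, att.max(axis=1))      # running row max
    exp_logic = np.exp(att - new_max[:, None])          # block weights
    log_scale = np.exp(max_exp - new_max)               # rescale old state
    accum = accum * log_scale[:, None] + exp_logic @ v  # rescale, then add
    sum_exp = sum_exp * log_scale + exp_logic.sum(axis=1)
    return accum, sum_exp, new_max
```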
@@ -168,10 +168,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     off_mid_o = (
         chunk_idx * mid_o_stride_chunk + cur_q_range[:, None] * mid_o_stride_h + d_off[None, :] * mid_o_stride_d
     )  # shape [Q_GROUP_SIZE, HEAD_DIM]
-    tl.store(mid_o_tensor + off_mid_o, accum, mask=cur_q_mask[:, None])
     off_mid_o_logexpsum = (
         chunk_idx * mid_o_logexpsum_stride_chunk + cur_q_range * mid_o_logexpsum_stride_h
     )  # shape [Q_GROUP_SIZE, 1]
+    tl.store(mid_o_tensor + off_mid_o, accum, mask=cur_q_mask[:, None])
     tl.store(mid_o_logexpsum_tensor + off_mid_o_logexpsum, sum_exp, mask=cur_q_mask)
 
 
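Grouping both offset computations ahead of the two stores changes nothing semantically, since the second offset never depends on the first store. Stage 1 thus leaves one partial result per (chunk, kv_head). The combine itself is not in this diff; the sketch below is a guess at the usual flash-decoding stage-2 reduction, assuming `mid_o` holds per-chunk softmax-normalized partial outputs and `mid_o_logexpsum` holds m + log(Σ exp) per chunk (the tensor name suggests a log-sum-exp; the exact quantity stored is not visible in these hunks):

```python
import numpy as np

def combine_chunks(mid_o, lse):
    # mid_o: [num_chunks, HEADS, HEAD_DIM], lse: [num_chunks, HEADS].
    m = lse.max(axis=0)                        # global max across chunks
    w = np.exp(lse - m[None, :])               # relative weight of each chunk
    out = (mid_o * w[:, :, None]).sum(axis=0)  # weighted partial outputs
    return out / w.sum(axis=0)[:, None]        # renormalize to a true softmax
```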
@@ -197,6 +197,7 @@ def gqa_flash_decoding_bib_stage1(
     grid size: [chunk_num, kv_head_num]
     """
     grid = (chunk_num, k.shape[1])
+    assert chunk_size >= run_config["BLOCK_N"] and chunk_size % run_config["BLOCK_N"] == 0
     _kernel_gqa_flash_decoding_bib_stage1[grid](
         q,
         k,
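The new assert makes the kernel's tiling contract explicit: every chunk must decompose into whole `BLOCK_N` tiles, so only the sequence end ever produces a partially masked block (and, presumably, so the per-chunk block count stays a clean `CHUNK_SIZE // BLOCK_N`). A standalone restatement of the check, with a hypothetical helper name:

```python
def check_run_config(chunk_size: int, block_n: int) -> None:
    # Chunk length must be a positive multiple of the inner tile size.
    assert chunk_size >= block_n and chunk_size % block_n == 0, (
        f"CHUNK_SIZE={chunk_size} must be a positive multiple of BLOCK_N={block_n}"
    )

check_run_config(chunk_size=256, block_n=64)   # passes
# check_run_config(chunk_size=48, block_n=32)  # would raise
```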
@@ -333,7 +334,7 @@ def gqa_flash_decoding_bib(q, k, v, infer_state, out=None, alloc_tensor_func=tor
         out_dtype=q.dtype,
     )
     if not hasattr(infer_state, "bib_info"):
-        chunk_size = run_config["BLOCK_N"]
+        chunk_size = run_config.get("CHUNK_SIZE", run_config["BLOCK_N"])
 
         # TODO: impl in triton
         b_seq_len = infer_state.b_seq_len
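With the stage-1 contract loosened, the chunk size is now read from the tuning config instead of being pinned to the tile size, and `dict.get` falls back to the old behavior when no `CHUNK_SIZE` key is present. A quick sketch of the effect, with illustrative values:

```python
run_config = {"BLOCK_N": 64, "CHUNK_SIZE": 256}
chunk_size = run_config.get("CHUNK_SIZE", run_config["BLOCK_N"])        # 256: 4 tiles per chunk

legacy_config = {"BLOCK_N": 64}                                         # no CHUNK_SIZE key
chunk_size = legacy_config.get("CHUNK_SIZE", legacy_config["BLOCK_N"])  # 64: old behavior
```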