
Commit fcaab70

small fix
1 parent dceb079 commit fcaab70

File tree: 1 file changed (+6, -5 lines)


lightllm/models/llama/triton_kernel/gqa_flash_decoding_bib.py

Lines changed: 6 additions & 5 deletions
@@ -116,7 +116,6 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     d_off = tl.arange(0, HEAD_DIM)
 
     cur_batch = tl.load(chunk2batch_tensor + chunk_idx)
-    cur_req = tl.load(b_req_idx_tensor + cur_batch)
     cur_seq_len = tl.load(b_seq_len_tensor + cur_batch)
     cur_start = tl.load(chunk2start_tensor + chunk_idx)
     cur_end = tl.minimum(cur_start + CHUNK_SIZE, cur_seq_len)
@@ -133,9 +132,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     max_exp = tl.zeros([Q_GROUP_SIZE], dtype=tl.float32) - float("inf")
     accum = tl.zeros([Q_GROUP_SIZE, HEAD_DIM], dtype=tl.float32)
 
-    for block_idx in tl.range(0, cur_block_num, 1):
+    for block_idx in range(0, cur_block_num, 1):
         block_range = cur_start + block_idx * BLOCK_N + tl.arange(0, BLOCK_N)  # shape [BLOCK_N]
         block_mask = block_range < cur_end  # shape [BLOCK_N]
+        cur_req = tl.load(b_req_idx_tensor + cur_batch)
         cur_kv_loc = tl.load(
             req_to_token_idx_tensor
             + cur_req * req_to_token_idx_stride_bs
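
For context on the indexing this hunk moves into the loop: req_to_token_idx_tensor acts as a per-request lookup table from token position to KV-cache slot. A rough NumPy sketch of that lookup, with an assumed table layout and made-up sizes (not the project's exact API):

import numpy as np

# Assumed layout: req_to_token_idx[num_reqs, max_seq_len] maps a request's
# token position to its slot index in the paged KV cache.
req_to_token_idx = np.arange(8 * 1024, dtype=np.int64).reshape(8, 1024)

cur_req = 3                             # hypothetical request id
block_range = np.arange(128, 128 + 64)  # one BLOCK_N window of token positions
block_mask = block_range < 150          # positions past cur_end contribute nothing
cur_kv_loc = np.where(block_mask, req_to_token_idx[cur_req, block_range], 0)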
@@ -156,10 +156,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
 
         exp_logic = tl.exp(att - new_max[:, None])
         log_scale = tl.exp(max_exp - new_max)
-        accum *= log_scale[:, None]
 
         v_off = cur_kv_loc[:, None] * v_stride_token + kv_head_idx * v_stride_h + d_off[None, :] * v_stride_d
         v = tl.load(v_tensor + v_off, mask=block_mask[:, None], other=0.0)
+        accum *= log_scale[:, None]
         accum += tl.dot(exp_logic.to(v.dtype), v)
 
         sum_exp = sum_exp * log_scale + tl.sum(exp_logic, axis=1)
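
The reordering above only moves the accumulator rescale past the v load; mathematically it is still the standard online-softmax update, where the running accumulator and denominator are rescaled by exp(old_max - new_max) before the new block is added. A minimal NumPy sketch of that update (hypothetical names, not the kernel's signature):

import numpy as np

def online_softmax_step(accum, sum_exp, max_exp, att_block, v_block):
    # att_block: [Q_GROUP_SIZE, BLOCK_N] scores, v_block: [BLOCK_N, HEAD_DIM] values
    new_max = np.maximum(max_exp, att_block.max(axis=1))       # running row max
    exp_logic = np.exp(att_block - new_max[:, None])           # probs under the new max
    log_scale = np.exp(max_exp - new_max)                      # correction for earlier terms
    accum = accum * log_scale[:, None] + exp_logic @ v_block   # rescale, then add this block
    sum_exp = sum_exp * log_scale + exp_logic.sum(axis=1)      # same rescale for the denominator
    return accum, sum_exp, new_max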
@@ -168,10 +168,10 @@ def _kernel_gqa_flash_decoding_bib_stage1(
     off_mid_o = (
         chunk_idx * mid_o_stride_chunk + cur_q_range[:, None] * mid_o_stride_h + d_off[None, :] * mid_o_stride_d
     )  # shape [Q_GROUP_SIZE, HEAD_DIM]
-    tl.store(mid_o_tensor + off_mid_o, accum, mask=cur_q_mask[:, None])
     off_mid_o_logexpsum = (
         chunk_idx * mid_o_logexpsum_stride_chunk + cur_q_range * mid_o_logexpsum_stride_h
     )  # shape [Q_GROUP_SIZE, 1]
+    tl.store(mid_o_tensor + off_mid_o, accum, mask=cur_q_mask[:, None])
     tl.store(mid_o_logexpsum_tensor + off_mid_o_logexpsum, sum_exp, mask=cur_q_mask)
 
 
@@ -197,6 +197,7 @@ def gqa_flash_decoding_bib_stage1(
     grid size: [chunk_num, kv_head_num]
     """
     grid = (chunk_num, k.shape[1])
+    assert chunk_size >= run_config["BLOCK_N"] and chunk_size % run_config["BLOCK_N"] == 0
     _kernel_gqa_flash_decoding_bib_stage1[grid](
         q,
         k,
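
The added assert encodes an assumption the chunking logic relies on: each chunk handed to one program instance must span at least one BLOCK_N tile and an exact number of them. A tiny sketch with assumed values:

BLOCK_N = 64       # assumed tile width along the sequence axis
chunk_size = 256   # assumed per-program chunk length

assert chunk_size >= BLOCK_N and chunk_size % BLOCK_N == 0
blocks_per_chunk = chunk_size // BLOCK_N  # 4 full tiles cover one chunk in this example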
@@ -333,7 +334,7 @@ def gqa_flash_decoding_bib(q, k, v, infer_state, out=None, alloc_tensor_func=tor
         out_dtype=q.dtype,
     )
     if not hasattr(infer_state, "bib_info"):
-        chunk_size = run_config["BLOCK_N"]
+        chunk_size = run_config.get("CHUNK_SIZE", run_config["BLOCK_N"])
 
         # TODO: impl in triton
         b_seq_len = infer_state.b_seq_len
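
With this change the caller prefers an explicit CHUNK_SIZE from the tuning config and only falls back to BLOCK_N when none is given. A minimal sketch with an assumed config dict (real run_config contents come from the kernel's tuning machinery):

run_config = {"BLOCK_N": 64}                                      # assumed example config
chunk_size = run_config.get("CHUNK_SIZE", run_config["BLOCK_N"])  # falls back to 64

run_config["CHUNK_SIZE"] = 256
chunk_size = run_config.get("CHUNK_SIZE", run_config["BLOCK_N"])  # now 256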
