Skip to content

Commit 89cc027

Browse files
author
wangzaijun
committed
fix
1 parent f56fca4 commit 89cc027

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py

Lines changed: 17 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -189,25 +189,23 @@ def _fwd_kernel_flash_decode_diverse_stage1(
189189
sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1)
190190
max_logic = new_max_logic
191191

192-
need_store = tl.where(block_n_size == 0, 0, 1)
193-
for _ in range(0, need_store, 1):
194-
off_mid_o = (
195-
offs_batch[:, None, None] * stride_mid_ob
196-
+ cur_q_head_range[None, :, None] * stride_mid_oh
197-
+ seq_start_block * stride_mid_os
198-
+ offs_d[None, None, :]
199-
)
200-
off_mid_o_logexpsum = (
201-
offs_batch[:, None] * stride_mid_o_eb + cur_q_head_range[None, :] * stride_mid_o_eh + seq_start_block
202-
)
203-
tl.store(
204-
Mid_O + off_mid_o,
205-
(acc / sum_exp[:, None]).reshape(BLOCK_BATCH, BLOCK_HEAD, BLOCK_HEADDIM),
206-
)
207-
tl.store(
208-
Mid_O_LogExpSum + off_mid_o_logexpsum,
209-
(max_logic + tl.log(sum_exp)).reshape(BLOCK_BATCH, BLOCK_HEAD),
210-
)
192+
off_mid_o = (
193+
offs_batch[:, None, None] * stride_mid_ob
194+
+ cur_q_head_range[None, :, None] * stride_mid_oh
195+
+ seq_start_block * stride_mid_os
196+
+ offs_d[None, None, :]
197+
)
198+
off_mid_o_logexpsum = (
199+
offs_batch[:, None] * stride_mid_o_eb + cur_q_head_range[None, :] * stride_mid_o_eh + seq_start_block
200+
)
201+
tl.store(
202+
Mid_O + off_mid_o,
203+
(acc / sum_exp[:, None]).reshape(BLOCK_BATCH, BLOCK_HEAD, BLOCK_HEADDIM),
204+
)
205+
tl.store(
206+
Mid_O_LogExpSum + off_mid_o_logexpsum,
207+
(max_logic + tl.log(sum_exp)).reshape(BLOCK_BATCH, BLOCK_HEAD),
208+
)
211209
return
212210

213211

0 commit comments

Comments
 (0)