Skip to content

Commit 04e5490

Browse files
committed
fix
1 parent c5bb0d0 commit 04e5490

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def _fwd_kernel_flash_decode_stage1(
4747
shared_batch_group_size = tl.load(b_mark_shared_group + cur_batch)
4848
if shared_batch_group_size == 0:
4949
return
50-
cur_batch = cur_batch - shared_batch_group_size
50+
cur_batch_end = cur_batch + 1
51+
cur_batch = cur_batch - (shared_batch_group_size - 1)
5152
cur_kv_head = tl.program_id(1)
5253
seq_start_block = tl.program_id(2)
5354

@@ -62,7 +63,7 @@ def _fwd_kernel_flash_decode_stage1(
6263
cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)
6364

6465
offs_batch = cur_batch + tl.arange(0, BLOCK_BATCH)
65-
offs_batch = tl.where(offs_batch < cur_batch + shared_batch_group_size, offs_batch, cur_batch)
66+
offs_batch = tl.where(offs_batch < cur_batch_end, offs_batch, cur_batch)
6667

6768
off_q = offs_batch[:, None, None] * stride_qbs + cur_q_head_range[None, :, None] * stride_qh + offs_d[None, None, :]
6869

0 commit comments

Comments
 (0)