Skip to content

Commit 843d211

Browse files
committed
update
1 parent f64ce3b commit 843d211

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,8 @@ def _fwd_kernel_flash_decode_stage2(
2222
BLOCK_SEQ: tl.constexpr,
2323
BLOCK_DMODEL: tl.constexpr,
2424
):
25-
cur_batch = tl.program_id(0)
26-
cur_head = tl.program_id(1)
25+
cur_head = tl.program_id(0)
26+
cur_batch = tl.program_id(1)
2727

2828
offs_d = tl.arange(0, BLOCK_DMODEL)
2929
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
@@ -57,7 +57,7 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq):
5757
Lk = mid_out.shape[-1]
5858
assert Lk in {16, 32, 64, 128, 256, 512}
5959
batch, head_num = mid_out.shape[0], mid_out.shape[1]
60-
grid = (batch, head_num)
60+
grid = (head_num, batch)
6161

6262
_fwd_kernel_flash_decode_stage2[grid](
6363
B_Seqlen,

0 commit comments

Comments (0)