Skip to content

Commit d499f90

Browse files
committed
update sliu grid
1 parent a9e0156 commit d499f90

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def _silu_and_mul_kernel_fast(
2323
stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
2424
stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)
2525

26-
m_block_index = tl.program_id(0)
27-
n_block_index = tl.program_id(1)
26+
n_block_index = tl.program_id(0)
27+
m_block_index = tl.program_id(1)
2828
n_offsets = n_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
2929
m_start_index = m_block_index * BLOCK_M
3030
m_end_index = (m_block_index + 1) * BLOCK_M
@@ -82,8 +82,8 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
8282
NUM_STAGES = run_config["NUM_STAGES"]
8383

8484
grid = (
85-
triton.cdiv(size_m, BLOCK_M),
8685
triton.cdiv(size_n, BLOCK_N),
86+
triton.cdiv(size_m, BLOCK_M),
8787
)
8888
NEED_MASK = (size_n % BLOCK_N) != 0
8989
_silu_and_mul_kernel_fast[grid](

0 commit comments

Comments
 (0)