Skip to content

Commit 4beacdf

Browse files
author
wangzaijun
committed
fix sb bug.
1 parent 7cf4e62 commit 4beacdf

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ def grouped_matmul(
387387

388388
expert_num, n, k = expert_weights.shape
389389
assert token_inputs.shape[1] == k
390+
assert expert_to_weights_scale.shape[0] == expert_num
390391
assert expert_to_token_index.shape == expert_to_weights.shape
391392
assert token_inputs.is_contiguous()
392393
assert expert_to_token_num.is_contiguous()
@@ -520,7 +521,7 @@ def fused_experts_impl(
520521

521522
intermediate_cache1 = alloc_tensor_func((M, topk_num, N), device=hidden_states.device, dtype=hidden_states.dtype)
522523
intermediate_cache2 = alloc_tensor_func(
523-
(M * topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
524+
(M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
524525
)
525526
intermediate_cache3 = alloc_tensor_func(
526527
(M, topk_num, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype
@@ -567,10 +568,10 @@ def fused_experts_impl(
567568
**run_config,
568569
)
569570

570-
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
571+
ops.silu_and_mul(intermediate_cache2.view(-1, N // 2), intermediate_cache1.view(-1, N))
571572

572573
grouped_matmul(
573-
intermediate_cache2,
574+
intermediate_cache2.view(-1, N // 2),
574575
a2_scale,
575576
expert_to_token_num,
576577
expert_to_tokens,

0 commit comments

Comments (0)