
Commit 94b8db8

Author: none
Commit message: fix
1 parent f7bd1c5

File tree: 2 files changed, +12 −2 lines

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 10 additions & 0 deletions
@@ -17,6 +17,7 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
+from lightllm.utils.envs_utils import is_triton_autotune_enabled
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)
@@ -353,6 +354,15 @@ def prefilled_group_gemm(
         )
         # gather and local reduce
         ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
+    else:
+        ######################################## warning ##################################################
+        # Keep parity with the autotune feature: every rank must run the same Triton kernels.
+        # In some special cases a rank receives 0 tokens, so launch the kernel on one dummy token.
+        if is_triton_autotune_enabled():
+            _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
+            _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
+            silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
+            _gemm_out_a, _silu_out = None, None

     return gather_out
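
The added else branch performs a deliberate throwaway launch of silu_and_mul_fwd: Triton autotuning benchmarks a kernel on its first launch, so if one rank receives 0 tokens and skips the launch while the others do not, the ranks diverge in which kernels they tune and can fall out of lockstep. Running the kernel once on a single zero-filled row keeps every rank on the same code path. The (1, N) input and (1, N // 2) output also suggest the kernel's semantics: silu_and_mul_fwd appears to be the usual fused gated activation. Below is a minimal PyTorch sketch of those assumed semantics; silu_and_mul_ref is illustrative, not the project's Triton kernel:

import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Assumed semantics of silu_and_mul_fwd: split the last dimension in half,
    # apply SiLU to the first half (the gate), multiply by the second half.
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up

# The dummy launch from the diff, with N = 8 for illustration:
x = torch.zeros((1, 8))
out = silu_and_mul_ref(x)  # shape (1, 4), i.e. (1, N // 2), matching _silu_out
assert out.shape == (1, 4) and bool((out == 0).all())  # silu(0) * 0 == 0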

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 2 additions & 2 deletions
@@ -192,8 +192,8 @@ def fused_experts_impl(
     # Keep parity with the autotune feature: every rank must run the same Triton kernels.
     # In some special cases a rank receives 0 tokens, so launch the kernel on one dummy token.
     if is_triton_autotune_enabled():
-        _gemm_out_a = torch.empty((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
-        _silu_out = torch.empty((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
+        _gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
+        _silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
         silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
         _gemm_out_a, _silu_out = None, None

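The two-line change itself swaps torch.empty for torch.zeros in the dummy tensors. The commit message is only "fix", but a plausible reading is that torch.empty returns uninitialized memory, which can decode to NaN or Inf, and those values would flow through the dummy silu_and_mul_fwd launch (and into any autotune timing or debug checks); torch.zeros guarantees a finite, well-defined input. A small sketch of the difference:

import torch

# torch.empty only allocates; the contents are whatever bits were already in
# the buffer, so they may decode to NaN or Inf for floating-point dtypes.
garbage = torch.empty((1, 8), dtype=torch.float16)

# torch.zeros initializes the buffer, so the dummy kernel input is always
# finite and the fused activation output stays finite as well.
safe = torch.zeros((1, 8), dtype=torch.float16)

assert bool(torch.isfinite(safe).all())  # guaranteed
print(torch.isfinite(garbage).all())     # not guaranteed: may print False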
