
Commit f7bd1c5

Author: none
Commit message: fix
1 parent bf1046d commit f7bd1c5

File tree: 1 file changed (+11, -0 lines changed)


lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 11 additions & 0 deletions
@@ -14,6 +14,7 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
+from lightllm.utils.envs_utils import is_triton_autotune_enabled
 import numpy as np

 logger = init_logger(__name__)
@@ -186,6 +187,16 @@ def fused_experts_impl(

         # gather and local reduce
         ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
+    else:
+        ######################################## warning ##################################################
+        # This keeps the Triton autotune feature in sync: every MoE rank must launch the same Triton kernels.
+        # In some special cases a rank receives 0 tokens, so we run the kernel on one dummy token instead.
+        if is_triton_autotune_enabled():
+            _gemm_out_a = torch.empty((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
+            _silu_out = torch.empty((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
+            silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
+            _gemm_out_a, _silu_out = None, None
+
     # normal combine
     combined_x, _, event = buffer.combine(
         gather_out,
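
In effect, the new else branch makes a rank that received zero tokens from the expert-parallel dispatch still launch the same Triton kernel once on a single dummy token, so Triton autotuning sees identical kernel launches on every rank. The sketch below illustrates that pattern outside lightllm; it is a minimal, hypothetical example: silu_and_mul here is a plain PyTorch stand-in for the fused silu_and_mul_fwd Triton kernel, and run_activation / autotune_enabled are illustrative names, not functions from the repository.

import torch
import torch.nn.functional as F


def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # Plain PyTorch stand-in for a fused SiLU-and-multiply kernel: split the
    # last dimension in half, apply SiLU to the first half and multiply it
    # elementwise by the second half.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up


def run_activation(gemm_out: torch.Tensor, autotune_enabled: bool) -> torch.Tensor:
    # gemm_out has shape (num_tokens, N); num_tokens can be 0 on a rank that
    # received no tokens from the expert-parallel dispatch.
    num_tokens, N = gemm_out.shape
    if num_tokens > 0:
        # Normal path: this rank has real tokens to process.
        return silu_and_mul(gemm_out)
    if autotune_enabled:
        # Pattern from the commit: with autotune enabled, launch the same
        # kernel once on a single dummy token so every rank runs identical
        # Triton kernels and the autotune results stay consistent across ranks.
        dummy = torch.empty((1, N), device=gemm_out.device, dtype=gemm_out.dtype)
        _ = silu_and_mul(dummy)
    # Return an empty result with the activation's output width (N // 2).
    return gemm_out.new_empty((0, N // 2))

As a quick check, run_activation(torch.randn(0, 8), autotune_enabled=True) still performs one dummy launch and returns a (0, 4) tensor, matching the shape the subsequent combine step would expect.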
