 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
+from lightllm.utils.envs_utils import is_triton_autotune_enabled
 import numpy as np
 
 logger = init_logger(__name__)
@@ -186,6 +187,16 @@ def fused_experts_impl(
 
         # gather and local reduce
        ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
+    else:
+        ######################################## warning ##################################################
+        # Keep the Triton autotune feature consistent: the MoE model must launch the same Triton
+        # kernels on every rank. A rank may receive 0 tokens, so run the kernel on one dummy token here.
+        if is_triton_autotune_enabled():
+            _gemm_out_a = torch.empty((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
+            _silu_out = torch.empty((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
+            silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
+            _gemm_out_a, _silu_out = None, None
+
     # normal combine
     combined_x, _, event = buffer.combine(
         gather_out,
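
The pattern in the added else branch generalizes: when Triton autotuning is enabled, every rank has to launch the same kernels, including a rank that received zero tokens after dispatch. The sketch below is a minimal, self-contained illustration of that idea, not lightllm code: run_silu_and_mul is a plain PyTorch stand-in for the fused silu_and_mul_fwd kernel, and launch_dummy_if_empty is a hypothetical helper name introduced only for this example.

import torch
import torch.nn.functional as F

def run_silu_and_mul(x: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for a fused kernel: out = silu(x[:, :N//2]) * x[:, N//2:].
    n_half = x.shape[-1] // 2
    out.copy_(F.silu(x[:, :n_half]) * x[:, n_half:])

def launch_dummy_if_empty(recv_tokens: torch.Tensor, N: int, autotune_enabled: bool) -> None:
    # If this rank received no tokens, still run the kernel once on a single
    # dummy token so every rank executes (and autotunes) the same kernel.
    if autotune_enabled and recv_tokens.shape[0] == 0:
        dummy_in = torch.empty((1, N), device=recv_tokens.device, dtype=recv_tokens.dtype)
        dummy_out = torch.empty((1, N // 2), device=recv_tokens.device, dtype=recv_tokens.dtype)
        run_silu_and_mul(dummy_in, dummy_out)

The outputs of the dummy launch are discarded; the call exists only so this rank executes the same kernel sequence as the ranks that did receive tokens.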