diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py index ca8d22f48..4e50576a9 100644 --- a/lightllm/common/fused_moe/topk_select.py +++ b/lightllm/common/fused_moe/topk_select.py @@ -41,19 +41,13 @@ def fused_topk( topk_weights = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device) topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) sgl_ops.topk_softmax( topk_weights, topk_ids, - token_expert_indicies, gating_output.float(), # TODO(woosuk): Optimize this. + renormalize=renormalize, ) - del token_expert_indicies # Not used. Will be used in the future. - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids