Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 62 additions & 10 deletions lightllm/common/fused_moe/topk_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from lightllm.common.vllm_kernel import _custom_ops as ops
from typing import Callable, List, Optional, Tuple

# Opt-in switch for the fused CUDA grouped-topk kernel (default: triton path).
use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ("ON", "TRUE", "1")


def fused_topk(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -90,6 +93,43 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


# This is used by the Deepseek-V2 model
def cuda_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
):
    """Select top-k experts per token via the fused CUDA grouped-topk kernel.

    Allocates the output buffers and delegates the actual selection to
    ``ops.grouped_topk``. ``correction_bias`` may be ``None``, in which case an
    all-zero bias of the same shape as ``gating_output`` is substituted.

    Returns:
        (topk_weights, topk_indices): float32 weights and int32 expert ids,
        each of shape ``(num_tokens, topk)``.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    num_tokens = gating_output.shape[0]
    device = hidden_states.device

    # Output / scratch buffers consumed by the CUDA kernel.
    topk_weights = torch.empty((num_tokens, topk), dtype=torch.float32, device=device)
    topk_indices = torch.empty((num_tokens, topk), dtype=torch.int32, device=device)
    # NOTE(review): sized by topk_group, matching the kernel's expectation — confirm
    # against the vllm_moe.grouped_topk signature.
    token_expert_indices = torch.empty((num_tokens, topk_group), dtype=torch.int32, device=device)
    group_scores = torch.empty((num_tokens, num_expert_group), dtype=torch.float32, device=device)

    if correction_bias is None:
        # The kernel always takes a bias tensor; a zero bias is a no-op.
        correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)

    ops.grouped_topk(
        topk_weights,
        correction_bias,
        topk_indices,
        token_expert_indices,
        gating_output.float(),
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        scoring_func,
        group_scores,
    )
    return topk_weights, topk_indices


def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand All @@ -109,16 +149,28 @@ def select_experts(
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = triton_grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
correction_bias=correction_bias,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
)
if use_cuda_grouped_topk:
topk_weights, topk_ids = cuda_grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
correction_bias=correction_bias,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
)
else:
topk_weights, topk_ids = triton_grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
correction_bias=correction_bias,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
Expand Down
28 changes: 28 additions & 0 deletions lightllm/common/vllm_kernel/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,34 @@ def topk_softmax(
torch.ops.vllm_moe.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)


def grouped_topk(
    topk_weights: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_indices: torch.Tensor,
    group_indices: torch.Tensor,
    gating_output: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    group_scores: "torch.Tensor | None" = None,
) -> None:
    """Invoke the ``vllm_moe.grouped_topk`` custom op (grouped expert routing).

    All tensor arguments are passed through positionally to the registered
    torch op; results are written in place into ``topk_weights``,
    ``topk_indices``, ``group_indices`` and ``group_scores`` — nothing is
    returned.

    NOTE(review): the positional order here must match the C++/CUDA op
    registration exactly; presumably ``group_scores`` is scratch space of
    shape ``(num_tokens, num_expert_group)`` — verify against the kernel,
    including whether it accepts ``None``.
    """
    torch.ops.vllm_moe.grouped_topk(
        topk_weights,
        correction_bias,
        topk_indices,
        group_indices,
        gating_output,
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        scoring_func,
        group_scores,
    )


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
Expand Down