
Commit 0a900c8
Author: sangchengmeng
Commit message: add grouped_topk_cuda
1 parent 743ddc3

File tree

2 files changed: +80, -12 lines


lightllm/common/fused_moe/topk_select.py

Lines changed: 63 additions & 12 deletions
@@ -17,10 +17,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import torch
 from lightllm.common.vllm_kernel import _custom_ops as ops
 from typing import Callable, List, Optional, Tuple
 
+use_cuda_grouped_topk = os.environ.get("GROUPED_TOPK_CUDA", "false").lower()
 
 def fused_topk(
     hidden_states: torch.Tensor,
@@ -61,7 +63,7 @@ def grouped_topk(
     topk_group: int = 0,
     scoring_func: str = "softmax",
 ):
-
+
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
     if scoring_func == "sigmoid":
         scores = torch.sigmoid(gating_output)
@@ -89,6 +91,43 @@ def grouped_topk(
 
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
+
+# This is used by the Deepseek-V2 model
+def grouped_topk_cuda(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    scoring_func: str = "softmax",
+):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    num_tokens = gating_output.shape[0]
+    num_experts = gating_output.shape[-1]
+    topk_weights = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.float32)
+    topk_indices = torch.empty(num_tokens, topk, device=hidden_states.device, dtype=torch.int32)
+    token_expert_indices = torch.empty(num_tokens, topk_group, device=hidden_states.device, dtype=torch.int32)
+    group_scores = torch.empty(num_tokens, num_expert_group, device=hidden_states.device, dtype=torch.float32)
+    if correction_bias is None:
+        correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)
+    ops.grouped_topk(
+        topk_weights,
+        correction_bias,
+        topk_indices,
+        token_expert_indices,
+        gating_output.float(),
+        num_expert_group,
+        topk_group,
+        topk,
+        renormalize,
+        scoring_func,
+        group_scores
+    )
+
+    return topk_weights, topk_indices
+
 
 def select_experts(
     hidden_states: torch.Tensor,
@@ -103,21 +142,33 @@ def select_experts(
     custom_routing_function: Optional[Callable] = None,
 ):
     from lightllm.common.fused_moe.topk_select import fused_topk, grouped_topk
-
     # DeekSeekv2 uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            correction_bias=correction_bias,
-            topk=top_k,
-            renormalize=renormalize,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-            scoring_func=scoring_func,
-        )
+        if use_cuda_grouped_topk == "true":
+            from lightllm.common.vllm_kernel import _custom_ops as ops
+            topk_weights, topk_ids = grouped_topk_cuda(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                correction_bias=correction_bias,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                scoring_func=scoring_func,
+            )
+        else:
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                correction_bias=correction_bias,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                scoring_func=scoring_func,
+            )
     elif custom_routing_function is None:
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
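
A minimal usage sketch of the new path (not part of the commit): because use_cuda_grouped_topk is read once when topk_select is imported, GROUPED_TOPK_CUDA has to be set in the environment beforehand. The token count, hidden size, expert count, and group/top-k settings below are illustrative assumptions, and the call needs CUDA tensors plus a built vllm_moe extension. Setting the flag to anything other than "true" keeps select_experts on the existing grouped_topk path.

import os
os.environ["GROUPED_TOPK_CUDA"] = "true"  # must be set before topk_select is imported

import torch
from lightllm.common.fused_moe.topk_select import grouped_topk_cuda

# Illustrative shapes: 4 tokens, 256 experts split into 8 groups, routing to the top 4 groups.
num_tokens, num_experts = 4, 256
hidden_states = torch.randn(num_tokens, 4096, device="cuda", dtype=torch.float16)
router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

# Passing correction_bias=None exercises the zeros_like fallback added in this commit.
topk_weights, topk_ids = grouped_topk_cuda(
    hidden_states=hidden_states,
    gating_output=router_logits,
    correction_bias=None,
    topk=8,
    renormalize=True,
    num_expert_group=8,
    topk_group=4,
    scoring_func="sigmoid",
)
print(topk_weights.shape, topk_ids.shape)  # both (num_tokens, topk)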

lightllm/common/vllm_kernel/_ops.py

Lines changed: 17 additions & 0 deletions
@@ -760,6 +760,23 @@ def topk_softmax(
 ) -> None:
     torch.ops.vllm_moe.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
 
+def grouped_topk(
+    topk_weights: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk_indices: torch.Tensor,
+    group_indices: torch.Tensor,
+    gating_output: torch.Tensor,
+    num_expert_group: int,
+    topk_group: int,
+    topk: int,
+    renormalize: bool,
+    scoring_func: str,
+    group_scores: torch.Tensor = None
+) -> None:
+    torch.ops.vllm_moe.grouped_topk(
+        topk_weights, correction_bias, topk_indices, group_indices, gating_output, num_expert_group,
+        topk_group, topk, renormalize, scoring_func, group_scores
+    )
 
 def reshape_and_cache(
     key: torch.Tensor,
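
For reference, a hedged sketch of the low-level wrapper's in-place contract, mirroring the buffer shapes that grouped_topk_cuda allocates in topk_select.py; the concrete sizes are illustrative assumptions, and the vllm_moe extension must expose the grouped_topk op.

import torch
from lightllm.common.vllm_kernel import _custom_ops as ops

# Illustrative sizes: 2 tokens, 64 experts in 8 groups, keep the top 4 groups and top 6 experts.
num_tokens, num_experts, num_expert_group, topk_group, topk = 2, 64, 8, 4, 6

gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
correction_bias = torch.zeros_like(gating_output)  # no bias correction
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
token_expert_indices = torch.empty(num_tokens, topk_group, device="cuda", dtype=torch.int32)
group_scores = torch.empty(num_tokens, num_expert_group, device="cuda", dtype=torch.float32)

# The wrapper is typed -> None: results are written into topk_weights / topk_indices.
ops.grouped_topk(
    topk_weights, correction_bias, topk_indices, token_expert_indices, gating_output,
    num_expert_group, topk_group, topk, True, "sigmoid", group_scores
)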
