Skip to content

Commit 9876ed1

Browse files
authored
add grouped_topk_cuda (#723)
1 parent 1980781 commit 9876ed1

File tree

2 files changed

+90
-10
lines changed

2 files changed

+90
-10
lines changed

lightllm/common/fused_moe/topk_select.py

Lines changed: 62 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,13 @@
1717
# See the License for the specific language governing permissions and
1818
# limitations under the License.
1919

20+
import os
2021
import torch
2122
from lightllm.common.vllm_kernel import _custom_ops as ops
2223
from typing import Callable, List, Optional, Tuple
2324

25+
# Opt-in switch read once at import time: true only when the environment
# variable LIGHTLLM_CUDA_GROUPED_TOPK is set to "on", "true" or "1"
# (compared case-insensitively). Unset defaults to "False", i.e. disabled.
use_cuda_grouped_topk = (
    os.environ.get("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ("ON", "TRUE", "1")
)
26+
2427

2528
def fused_topk(
2629
hidden_states: torch.Tensor,
@@ -90,6 +93,43 @@ def grouped_topk(
9093
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
9194

9295

96+
# This is used by the Deepseek-V2 model
97+
def cuda_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
):
    """Run grouped top-k expert selection through ``ops.grouped_topk``.

    Allocates the output and scratch buffers on ``hidden_states.device``,
    hands them to the kernel wrapper, and returns the two output buffers.

    Args:
        hidden_states: Only its first dimension (token count) and device are
            read here.
        gating_output: Router scores; its first dimension must equal that of
            ``hidden_states``. It is converted with ``.float()`` before being
            passed to the kernel.
        correction_bias: Bias tensor forwarded to the kernel. When ``None``,
            a float32 zero tensor with the shape of ``gating_output`` is
            used instead.
        topk: Width of the second dimension of both returned tensors.
        renormalize: Forwarded unchanged to the kernel.
        num_expert_group: Forwarded to the kernel; also the second dimension
            of the group-score scratch buffer.
        topk_group: Forwarded to the kernel; also the second dimension of
            the group-index scratch buffer.
        scoring_func: Forwarded unchanged to the kernel.

    Returns:
        ``(topk_weights, topk_indices)``: a float32 and an int32 tensor, each
        of shape ``(num_tokens, topk)``. Both are created uninitialised and
        are presumably filled in place by the kernel (confirm against the
        kernel implementation).

    Raises:
        AssertionError: If the token counts of ``hidden_states`` and
            ``gating_output`` differ.
    """
    hidden_token_count = hidden_states.shape[0]
    num_tokens = gating_output.shape[0]
    assert hidden_token_count == num_tokens, "Number of tokens mismatch"

    device = hidden_states.device
    float_buffer = dict(device=device, dtype=torch.float32)
    int_buffer = dict(device=device, dtype=torch.int32)

    # Output buffers handed back to the caller.
    weights_out = torch.empty(num_tokens, topk, **float_buffer)
    indices_out = torch.empty(num_tokens, topk, **int_buffer)

    # Scratch buffers passed to the kernel; they are not returned.
    group_index_scratch = torch.empty(num_tokens, topk_group, **int_buffer)
    group_score_scratch = torch.empty(num_tokens, num_expert_group, **float_buffer)

    # The kernel is always given a bias tensor, so substitute zeros for None.
    if correction_bias is not None:
        bias = correction_bias
    else:
        bias = torch.zeros_like(gating_output, dtype=torch.float32)

    gating_fp32 = gating_output.float()

    # Arguments are positional; keep this order in sync with ops.grouped_topk.
    ops.grouped_topk(
        weights_out,
        bias,
        indices_out,
        group_index_scratch,
        gating_fp32,
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        scoring_func,
        group_score_scratch,
    )

    return weights_out, indices_out
131+
132+
93133
def select_experts(
94134
hidden_states: torch.Tensor,
95135
router_logits: torch.Tensor,
@@ -109,16 +149,28 @@ def select_experts(
109149
if use_grouped_topk:
110150
assert topk_group is not None
111151
assert num_expert_group is not None
112-
topk_weights, topk_ids = triton_grouped_topk(
113-
hidden_states=hidden_states,
114-
gating_output=router_logits,
115-
correction_bias=correction_bias,
116-
topk=top_k,
117-
renormalize=renormalize,
118-
num_expert_group=num_expert_group,
119-
topk_group=topk_group,
120-
scoring_func=scoring_func,
121-
)
152+
if use_cuda_grouped_topk:
153+
topk_weights, topk_ids = cuda_grouped_topk(
154+
hidden_states=hidden_states,
155+
gating_output=router_logits,
156+
correction_bias=correction_bias,
157+
topk=top_k,
158+
renormalize=renormalize,
159+
num_expert_group=num_expert_group,
160+
topk_group=topk_group,
161+
scoring_func=scoring_func,
162+
)
163+
else:
164+
topk_weights, topk_ids = triton_grouped_topk(
165+
hidden_states=hidden_states,
166+
gating_output=router_logits,
167+
correction_bias=correction_bias,
168+
topk=top_k,
169+
renormalize=renormalize,
170+
num_expert_group=num_expert_group,
171+
topk_group=topk_group,
172+
scoring_func=scoring_func,
173+
)
122174
elif custom_routing_function is None:
123175
topk_weights, topk_ids = fused_topk(
124176
hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize

lightllm/common/vllm_kernel/_ops.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -761,6 +761,34 @@ def topk_softmax(
761761
torch.ops.vllm_moe.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output)
762762

763763

764+
def grouped_topk(
    topk_weights: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_indices: torch.Tensor,
    group_indices: torch.Tensor,
    gating_output: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    scoring_func: str,
    group_scores: torch.Tensor = None,
) -> None:
    """Dispatch to the ``vllm_moe.grouped_topk`` custom torch op.

    Every argument is forwarded positionally, in declaration order, without
    modification. Nothing is returned; results are presumably written into
    the tensor arguments by the op (confirm against the kernel source).
    """
    # The tuple order mirrors the positional order the op is called with.
    op_args = (
        topk_weights,
        correction_bias,
        topk_indices,
        group_indices,
        gating_output,
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        scoring_func,
        group_scores,
    )
    torch.ops.vllm_moe.grouped_topk(*op_args)
790+
791+
764792
def reshape_and_cache(
765793
key: torch.Tensor,
766794
value: torch.Tensor,

0 commit comments

Comments
 (0)