1717# See the License for the specific language governing permissions and
1818# limitations under the License.
1919
20+ import os
2021import torch
2122from lightllm .common .vllm_kernel import _custom_ops as ops
2223from typing import Callable , List , Optional , Tuple
2324
# Opt-in switch read once at import time from the GROUPED_TOPK_CUDA environment
# variable (defaults to "false"). The value is kept as a lower-cased string
# because select_experts compares it against the literal "true".
use_cuda_grouped_topk = os.getenv("GROUPED_TOPK_CUDA", "false").lower()
2426
2527def fused_topk (
2628 hidden_states : torch .Tensor ,
@@ -61,7 +63,7 @@ def grouped_topk(
6163 topk_group : int = 0 ,
6264 scoring_func : str = "softmax" ,
6365):
64-
66+
6567 assert hidden_states .shape [0 ] == gating_output .shape [0 ], "Number of tokens mismatch"
6668 if scoring_func == "sigmoid" :
6769 scores = torch .sigmoid (gating_output )
@@ -89,6 +91,43 @@ def grouped_topk(
8991
9092 return topk_weights .to (torch .float32 ), topk_ids .to (torch .int32 )
9193
# This is used by the Deepseek-V2 model
def grouped_topk_cuda(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: Optional[torch.Tensor],
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run grouped top-k expert selection through the ``ops.grouped_topk`` kernel.

    This function only allocates the output/scratch buffers and forwards the
    arguments; the selection itself happens inside the custom op.

    Args:
        hidden_states: Only its first dimension (token count) and its device
            are read here.
        gating_output: Router logits. Its first dimension must equal that of
            ``hidden_states``. It is cast to float32 before being handed to
            the kernel.
        correction_bias: Bias forwarded to the kernel. When ``None``, a float32
            zero tensor shaped like ``gating_output`` is substituted.
        topk: Width of the returned tensors (experts kept per token).
        renormalize: Forwarded to the kernel unchanged.
        num_expert_group: Forwarded to the kernel; also the second dimension
            of the ``group_scores`` scratch buffer.
        topk_group: Forwarded to the kernel; also the second dimension of the
            ``token_expert_indices`` scratch buffer.
        scoring_func: Forwarded to the kernel unchanged.

    Returns:
        ``(topk_weights, topk_indices)``: a float32 and an int32 tensor, both
        of shape ``(num_tokens, topk)`` on ``hidden_states.device``. They are
        allocated uninitialised and are presumably filled in place by the
        kernel -- confirm against the op's implementation.

    Raises:
        AssertionError: If ``hidden_states`` and ``gating_output`` disagree on
            the number of tokens.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    num_tokens = gating_output.shape[0]
    device = hidden_states.device

    # Output buffers returned to the caller.
    topk_weights = torch.empty(num_tokens, topk, device=device, dtype=torch.float32)
    topk_indices = torch.empty(num_tokens, topk, device=device, dtype=torch.int32)

    # Scratch buffers required by the kernel signature; never read back here.
    token_expert_indices = torch.empty(num_tokens, topk_group, device=device, dtype=torch.int32)
    group_scores = torch.empty(num_tokens, num_expert_group, device=device, dtype=torch.float32)

    # The kernel is always given a bias tensor, so a missing bias becomes zeros.
    if correction_bias is None:
        correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)

    # NOTE(review): positional order mirrors the existing call site; verify it
    # against the binding in lightllm.common.vllm_kernel._custom_ops.
    ops.grouped_topk(
        topk_weights,
        correction_bias,
        topk_indices,
        token_expert_indices,
        gating_output.float(),
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        scoring_func,
        group_scores,
    )

    return topk_weights, topk_indices
130+
92131
93132def select_experts (
94133 hidden_states : torch .Tensor ,
@@ -103,21 +142,33 @@ def select_experts(
103142 custom_routing_function : Optional [Callable ] = None ,
104143):
105144 from lightllm .common .fused_moe .topk_select import fused_topk , grouped_topk
106-
    # DeepSeek-V2 uses grouped_top_k
108146 if use_grouped_topk :
109147 assert topk_group is not None
110148 assert num_expert_group is not None
111- topk_weights , topk_ids = grouped_topk (
112- hidden_states = hidden_states ,
113- gating_output = router_logits ,
114- correction_bias = correction_bias ,
115- topk = top_k ,
116- renormalize = renormalize ,
117- num_expert_group = num_expert_group ,
118- topk_group = topk_group ,
119- scoring_func = scoring_func ,
120- )
149+ if use_cuda_grouped_topk == "true" :
150+ from lightllm .common .vllm_kernel import _custom_ops as ops
151+ topk_weights , topk_ids = grouped_topk_cuda (
152+ hidden_states = hidden_states ,
153+ gating_output = router_logits ,
154+ correction_bias = correction_bias ,
155+ topk = top_k ,
156+ renormalize = renormalize ,
157+ num_expert_group = num_expert_group ,
158+ topk_group = topk_group ,
159+ scoring_func = scoring_func ,
160+ )
161+ else :
162+ topk_weights , topk_ids = grouped_topk (
163+ hidden_states = hidden_states ,
164+ gating_output = router_logits ,
165+ correction_bias = correction_bias ,
166+ topk = top_k ,
167+ renormalize = renormalize ,
168+ num_expert_group = num_expert_group ,
169+ topk_group = topk_group ,
170+ scoring_func = scoring_func ,
171+ )
121172 elif custom_routing_function is None :
122173 topk_weights , topk_ids = fused_topk (
123174 hidden_states = hidden_states , gating_output = router_logits , topk = top_k , renormalize = renormalize
0 commit comments