1717# See the License for the specific language governing permissions and
1818# limitations under the License.
1919
20+ import os
2021import torch
2122from lightllm .common .vllm_kernel import _custom_ops as ops
2223from typing import Callable , List , Optional , Tuple
2324
# Opt-in switch for the CUDA grouped-topk kernel; accepts ON / TRUE / 1
# (case-insensitive) in the LIGHTLLM_CUDA_GROUPED_TOPK environment variable.
_cuda_grouped_topk_env = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False")
use_cuda_grouped_topk = _cuda_grouped_topk_env.upper() in ("ON", "TRUE", "1")
2427
2528def fused_topk (
2629 hidden_states : torch .Tensor ,
@@ -90,6 +93,43 @@ def grouped_topk(
9093 return topk_weights .to (torch .float32 ), topk_ids .to (torch .int32 )
9194
9295
# This is used by the Deepseek-V2 model
def cuda_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
):
    """Select top-k experts per token with the CUDA grouped-topk kernel.

    Args:
        hidden_states: token activations; only its row count and device are used here.
        gating_output: router logits, one row per token (same row count as hidden_states).
        correction_bias: optional bias tensor for the kernel; zeros are substituted when None.
        topk: number of experts chosen per token.
        renormalize: whether the kernel renormalizes the selected weights.
        num_expert_group: number of expert groups (kernel parameter).
        topk_group: number of groups kept per token (kernel parameter).
        scoring_func: scoring mode string forwarded to the kernel (e.g. "softmax").

    Returns:
        Tuple of (weights, expert ids): float32 weights of shape (num_tokens, topk)
        and int32 expert indices of shape (num_tokens, topk), filled by the kernel.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    num_tokens = gating_output.shape[0]
    device = hidden_states.device

    # Pre-allocated output / scratch buffers the kernel writes into.
    # NOTE(review): token_expert_idx is sized by topk_group while the other
    # outputs use topk — presumably the kernel's contract; confirm against
    # the _custom_ops.grouped_topk signature.
    weights = torch.empty(num_tokens, topk, device=device, dtype=torch.float32)
    expert_ids = torch.empty(num_tokens, topk, device=device, dtype=torch.int32)
    token_expert_idx = torch.empty(num_tokens, topk_group, device=device, dtype=torch.int32)
    grp_scores = torch.empty(num_tokens, num_expert_group, device=device, dtype=torch.float32)

    # Kernel always receives a bias tensor; use zeros when the caller has none.
    if correction_bias is None:
        correction_bias = torch.zeros_like(gating_output, dtype=torch.float32)

    ops.grouped_topk(
        weights,
        correction_bias,
        expert_ids,
        token_expert_idx,
        gating_output.float(),
        num_expert_group,
        topk_group,
        topk,
        renormalize,
        scoring_func,
        grp_scores,
    )
    return weights, expert_ids
131+
132+
93133def select_experts (
94134 hidden_states : torch .Tensor ,
95135 router_logits : torch .Tensor ,
@@ -109,16 +149,28 @@ def select_experts(
109149 if use_grouped_topk :
110150 assert topk_group is not None
111151 assert num_expert_group is not None
112- topk_weights , topk_ids = triton_grouped_topk (
113- hidden_states = hidden_states ,
114- gating_output = router_logits ,
115- correction_bias = correction_bias ,
116- topk = top_k ,
117- renormalize = renormalize ,
118- num_expert_group = num_expert_group ,
119- topk_group = topk_group ,
120- scoring_func = scoring_func ,
121- )
152+ if use_cuda_grouped_topk :
153+ topk_weights , topk_ids = cuda_grouped_topk (
154+ hidden_states = hidden_states ,
155+ gating_output = router_logits ,
156+ correction_bias = correction_bias ,
157+ topk = top_k ,
158+ renormalize = renormalize ,
159+ num_expert_group = num_expert_group ,
160+ topk_group = topk_group ,
161+ scoring_func = scoring_func ,
162+ )
163+ else :
164+ topk_weights , topk_ids = triton_grouped_topk (
165+ hidden_states = hidden_states ,
166+ gating_output = router_logits ,
167+ correction_bias = correction_bias ,
168+ topk = top_k ,
169+ renormalize = renormalize ,
170+ num_expert_group = num_expert_group ,
171+ topk_group = topk_group ,
172+ scoring_func = scoring_func ,
173+ )
122174 elif custom_routing_function is None :
123175 topk_weights , topk_ids = fused_topk (
124176 hidden_states = hidden_states , gating_output = router_logits , topk = top_k , renormalize = renormalize
0 commit comments