 from lightllm.distributed import dist_group_manager
 from lightllm.common.fused_moe.topk_select import select_experts
 from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
+from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import (
     per_token_group_quant_fp8,
     tma_align_input_scale,
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
+from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -40,6 +43,7 @@ def __init__(
     ) -> None:
         super().__init__()
 
+        self.layer_num = layer_num
         self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
         self.quantized_weight = quant_cfg.quantized_weight
         if self.quant_method is not None:
@@ -60,15 +64,26 @@ def __init__(
 
         global_world_size = get_global_world_size()
         self.global_rank_ = get_global_rank()
+        self.redundancy_expert_num = get_redundancy_expert_num()
+        self.redundancy_expert_ids = get_redundancy_expert_ids(layer_num)
+        logger.info(
+            f"global_rank {self.global_rank_} layer_index {layer_num} redundancy_expert_ids: {self.redundancy_expert_ids}"
+        )
+        self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda")
+        self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda")
+        self.total_expert_num_contain_redundancy = (
+            self.n_routed_experts + self.redundancy_expert_num * global_world_size
+        )
         assert self.n_routed_experts % global_world_size == 0
         self.ep_n_routed_experts = self.n_routed_experts // global_world_size
-        self.experts_up_projs = [None] * self.ep_n_routed_experts
-        self.experts_gate_projs = [None] * self.ep_n_routed_experts
-        self.experts_up_proj_scales = [None] * self.ep_n_routed_experts
-        self.experts_gate_proj_scales = [None] * self.ep_n_routed_experts
+        ep_load_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
+        self.experts_up_projs = [None] * ep_load_expert_num
+        self.experts_gate_projs = [None] * ep_load_expert_num
+        self.experts_up_proj_scales = [None] * ep_load_expert_num
+        self.experts_gate_proj_scales = [None] * ep_load_expert_num
         self.e_score_correction_bias = None
-        self.w2_list = [None] * self.ep_n_routed_experts
-        self.w2_scale_list = [None] * self.ep_n_routed_experts
+        self.w2_list = [None] * ep_load_expert_num
+        self.w2_scale_list = [None] * ep_load_expert_num
         self.scoring_func = network_config["scoring_func"]
         self.w1 = [None, None]  # weight, weight_scale
         self.w2 = [None, None]  # weight, weight_scale
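For a sense of the expert bookkeeping introduced above, a quick worked example; the numbers are purely illustrative (the real values come from the model config and the redundancy settings):

```python
# Illustrative figures only, e.g. a DeepSeek-V3-style MoE served on 16 EP ranks.
n_routed_experts = 256       # routed experts defined by the checkpoint
global_world_size = 16       # expert-parallel ranks
redundancy_expert_num = 1    # redundant expert copies hosted per rank

ep_n_routed_experts = n_routed_experts // global_world_size        # 16 regular experts per rank
ep_load_expert_num = ep_n_routed_experts + redundancy_expert_num   # 17 expert weights loaded per rank
total_expert_num_contain_redundancy = (                            # 272 expert slots seen by dispatch
    n_routed_experts + redundancy_expert_num * global_world_size
)
```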
@@ -84,6 +99,9 @@ def __init__(
         self.lock = threading.Lock()
         # init buffer
 
+        # auto update redundancy expert vars
+        self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert
+
     def experts(
         self,
         input_tensor,
@@ -106,6 +124,17 @@ def experts(
             num_expert_group=num_expert_group,
             scoring_func=self.scoring_func,
         )
+
+        if self.redundancy_expert_num > 0:
+            redundancy_topk_ids_repair(
+                topk_ids=topk_ids,
+                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
+                ep_expert_num=self.ep_n_routed_experts,
+                global_rank=self.global_rank_,
+                expert_counter=self.routed_expert_counter_tensor,
+                enable_counter=self.auto_update_redundancy_expert,
+            )
+
         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
         return fused_experts_impl(
@@ -114,7 +143,7 @@ def experts(
             w2=w2,
             topk_weights=topk_weights,
             topk_idx=topk_ids.to(torch.long),
-            num_experts=self.n_routed_experts,  # number of all experts
+            num_experts=self.total_expert_num_contain_redundancy,  # number of all experts, including redundant copies
             buffer=dist_group_manager.ep_buffer,
             is_prefill=is_prefill,
             use_fp8_w8a8=self.use_fp8_w8a8,
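With redundant copies in play, the expert id space seen by DeepEP dispatch grows from `n_routed_experts` to `total_expert_num_contain_redundancy`, and `redundancy_topk_ids_repair` rewrites `topk_ids` into that enlarged space (optionally accumulating per-expert routing counts in `routed_expert_counter_tensor` when `enable_counter` is set). The kernel itself is not part of this diff; the snippet below is only a hedged sketch of how an id in the enlarged space would map back to a rank and a local weight slot, assuming DeepEP's usual contiguous per-rank partitioning:

```python
# Hedged sketch only: assumes each rank owns (ep_expert_num + redundancy_expert_num)
# consecutive slots in the enlarged id space, with its regular experts first and its
# redundant copies in the trailing slots. The real remapping lives in the Triton
# kernel redundancy_topk_ids_repair.
def locate_expert(expert_id: int, ep_expert_num: int, redundancy_expert_num: int):
    slots_per_rank = ep_expert_num + redundancy_expert_num
    owner_rank = expert_id // slots_per_rank          # rank the token is dispatched to
    local_slot = expert_id % slots_per_rank           # index into that rank's fused w1/w2
    is_redundant_copy = local_slot >= ep_expert_num   # trailing slots hold redundant experts
    return owner_rank, local_slot, is_redundant_copy
```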
@@ -142,13 +171,24 @@ def low_latency_dispatch(
             num_expert_group=self.n_group,
             scoring_func=self.scoring_func,
         )
+
+        if self.redundancy_expert_num > 0:
+            redundancy_topk_ids_repair(
+                topk_ids=topk_idx,
+                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
+                ep_expert_num=self.ep_n_routed_experts,
+                global_rank=self.global_rank_,
+                expert_counter=self.routed_expert_counter_tensor,
+                enable_counter=self.auto_update_redundancy_expert,
+            )
+
         topk_idx = topk_idx.to(torch.long)
         num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank()
         recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch(
             hidden_states,
             topk_idx,
             num_max_dispatch_tokens_per_rank,
-            self.n_routed_experts,
+            self.total_expert_num_contain_redundancy,
             use_fp8=self.use_fp8_w8a8,
             async_finish=False,
             return_recv_hook=True,
@@ -171,6 +211,15 @@ def select_experts_and_quant_input(
             num_expert_group=self.n_group,
             scoring_func=self.scoring_func,
         )
+        if self.redundancy_expert_num > 0:
+            redundancy_topk_ids_repair(
+                topk_ids=topk_idx,
+                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
+                ep_expert_num=self.ep_n_routed_experts,
+                global_rank=self.global_rank_,
+                expert_counter=self.routed_expert_counter_tensor,
+                enable_counter=self.auto_update_redundancy_expert,
+            )
         M, K = hidden_states.shape
         w1, w1_scale = self.w1
         block_size_k = 0
@@ -190,7 +239,6 @@ def dispatch(
         overlap_event: Optional[Any] = None,
     ):
         buffer = dist_group_manager.ep_buffer
-        num_experts = self.n_routed_experts
         # get_dispatch_layout
         (
             num_tokens_per_rank,
@@ -199,7 +247,11 @@ def dispatch(
             is_token_in_rank,
             previous_event,
         ) = buffer.get_dispatch_layout(
-            topk_idx, num_experts, previous_event=overlap_event, async_finish=True, allocate_on_comm_stream=True
+            topk_idx,
+            self.total_expert_num_contain_redundancy,
+            previous_event=overlap_event,
+            async_finish=True,
+            allocate_on_comm_stream=True,
         )
         recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
             qinput_tensor,
@@ -342,16 +394,18 @@ def _fuse(self):
             and None not in self.experts_gate_projs
             and None not in self.w2_list
         ):
-            w1_list = []
-            for i_experts in range(self.ep_n_routed_experts):
-                expert_gate_up_proj = torch.cat(
-                    [self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0
-                )
-                expert_gate_up_proj = expert_gate_up_proj
-                w1_list.append(expert_gate_up_proj)
-
-            inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1]
-            w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
+            gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape
+            up_out_dim, up_in_dim = self.experts_up_projs[0].shape
+            assert gate_in_dim == up_in_dim
+            dtype = self.experts_gate_projs[0].dtype
+            total_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
+
+            w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu")
+
+            for i_experts in range(total_expert_num):
+                w1[i_experts, 0:gate_out_dim, :] = self.experts_gate_projs[i_experts]
+                w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts]
+
             inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
             w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
             if not self.quantized_weight and self.quant_method is not None:
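The rewritten fusion preallocates one contiguous `w1` and copies each expert's gate and up projections into adjacent row blocks, rather than concatenating per expert and flattening; the per-expert layout is unchanged. A toy check of that layout (shapes made up for the example):

```python
import torch

# Toy shapes for illustration only.
ep_load_expert_num, gate_out_dim, up_out_dim, hidden_size = 3, 4, 4, 8
gate = [torch.randn(gate_out_dim, hidden_size) for _ in range(ep_load_expert_num)]
up = [torch.randn(up_out_dim, hidden_size) for _ in range(ep_load_expert_num)]

w1 = torch.empty((ep_load_expert_num, gate_out_dim + up_out_dim, hidden_size))
for e in range(ep_load_expert_num):
    w1[e, :gate_out_dim, :] = gate[e]  # gate_proj rows first ...
    w1[e, gate_out_dim:, :] = up[e]    # ... up_proj rows after

# Row block order matches the old torch.cat([gate, up], dim=0) per expert.
assert torch.equal(w1[1], torch.cat([gate[1], up[1]], dim=0))
```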
@@ -372,17 +426,20 @@ def _fuse_weight_scale(self):
             and None not in self.experts_gate_proj_scales
             and None not in self.w2_scale_list
         ):
-            w1_scale_list = []
-            for i_experts in range(self.ep_n_routed_experts):
-                expert_gate_up_proj_scale = torch.cat(
-                    [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0
-                )
-                w1_scale_list.append(expert_gate_up_proj_scale)
-
-            inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1]
-            w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view(
-                len(w1_scale_list), inter_shape, hidden_size
+            gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape
+            up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape
+            assert gate_in_dim == up_in_dim
+            dtype = self.experts_gate_proj_scales[0].dtype
+            total_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
+
+            w1_scale = torch.empty(
+                (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu"
             )
+
+            for i_experts in range(total_expert_num):
+                w1_scale[i_experts, 0:gate_out_dim, :] = self.experts_gate_proj_scales[i_experts]
+                w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts]
+
             inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1]
             w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view(
                 len(self.w2_scale_list), inter_shape, hidden_size
@@ -411,7 +468,20 @@ def load_hf_weights(self, weights):
                 if w2_weight in weights:
                     self.w2_list[i_experts_ep] = weights[w2_weight]
 
-        if self.quant_method is not None:
+        # Load weight parameters for redundant experts.
+        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
+            i_experts = redundant_expert_id
+            w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
+            w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
+            w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
+            if w1_weight in weights:
+                self.experts_gate_projs[n_expert_ep + i] = weights[w1_weight]
+            if w3_weight in weights:
+                self.experts_up_projs[n_expert_ep + i] = weights[w3_weight]
+            if w2_weight in weights:
+                self.w2_list[n_expert_ep + i] = weights[w2_weight]
+
+        if self.quantized_weight:
             self._load_weight_scale(weights)
         self._fuse()
 
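Redundant experts are loaded from the same checkpoint tensors as the originals and simply copied again into the trailing `redundancy_expert_num` slots of the per-rank weight lists, right after the rank's regular experts. A hedged illustration of the key construction with made-up values (the real ones come from `self.weight_prefix`, `self.w1_weight_name`, and so on):

```python
# Hypothetical values for illustration only.
weight_prefix = "model.layers.3.mlp.experts"  # assumed shape of self.weight_prefix
w1_weight_name = "gate_proj"                  # assumed value of self.w1_weight_name
n_expert_ep = 16                              # regular experts hosted by this rank
redundancy_expert_ids = [42]                  # this layer's redundant expert(s) on this rank

for i, expert_id in enumerate(redundancy_expert_ids):
    key = f"{weight_prefix}.{expert_id}.{w1_weight_name}.weight"
    # e.g. "model.layers.3.mlp.experts.42.gate_proj.weight" lands at list index
    # n_expert_ep + i == 16, i.e. the first slot after the 16 regular experts.
    print(key, "->", n_expert_ep + i)
```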
@@ -430,6 +500,19 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None:
                 if w2_scale in weights:
                     self.w2_scale_list[i_experts_ep] = weights[w2_scale]
 
+        # Load scale parameters for redundant experts.
+        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
+            i_experts = redundant_expert_id
+            w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}"
+            w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}"
+            w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}"
+            if w1_scale in weights:
+                self.experts_gate_proj_scales[n_expert_ep + i] = weights[w1_scale]
+            if w3_scale in weights:
+                self.experts_up_proj_scales[n_expert_ep + i] = weights[w3_scale]
+            if w2_scale in weights:
+                self.w2_scale_list[n_expert_ep + i] = weights[w2_scale]
+
     def _cuda(self, cpu_tensor):
         device_id = get_current_device_id()
         if self.quantized_weight: