
Commit ec3cb03

[feature] add redundancy expert (#910)
1 parent a5265c4 commit ec3cb03

File tree

13 files changed: +795, -38 lines


lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 0 deletions
@@ -577,6 +577,9 @@ def _check_max_len_infer(self):
             logger.info("disable_check_max_len_infer is true")
             return
 
+        # do a one-time synchronization across ranks
+        torch.distributed.barrier()
+
         # simulate a prefill at the maximum length and check whether it triggers OOM
         try:
             logger.info("begin check max_len infer")

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 114 additions & 31 deletions
@@ -9,11 +9,14 @@
 from lightllm.distributed import dist_group_manager
 from lightllm.common.fused_moe.topk_select import select_experts
 from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
+from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import (
     per_token_group_quant_fp8,
     tma_align_input_scale,
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
+from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -40,6 +43,7 @@ def __init__(
     ) -> None:
         super().__init__()
 
+        self.layer_num = layer_num
         self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
         self.quantized_weight = quant_cfg.quantized_weight
         if self.quant_method is not None:
@@ -60,15 +64,26 @@ def __init__(
 
         global_world_size = get_global_world_size()
         self.global_rank_ = get_global_rank()
+        self.redundancy_expert_num = get_redundancy_expert_num()
+        self.redundancy_expert_ids = get_redundancy_expert_ids(layer_num)
+        logger.info(
+            f"global_rank {self.global_rank_} layerindex {layer_num} redundancy_expertids: {self.redundancy_expert_ids}"
+        )
+        self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda")
+        self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda")
+        self.total_expert_num_contain_redundancy = (
+            self.n_routed_experts + self.redundancy_expert_num * global_world_size
+        )
         assert self.n_routed_experts % global_world_size == 0
         self.ep_n_routed_experts = self.n_routed_experts // global_world_size
-        self.experts_up_projs = [None] * self.ep_n_routed_experts
-        self.experts_gate_projs = [None] * self.ep_n_routed_experts
-        self.experts_up_proj_scales = [None] * self.ep_n_routed_experts
-        self.experts_gate_proj_scales = [None] * self.ep_n_routed_experts
+        ep_load_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
+        self.experts_up_projs = [None] * ep_load_expert_num
+        self.experts_gate_projs = [None] * ep_load_expert_num
+        self.experts_up_proj_scales = [None] * ep_load_expert_num
+        self.experts_gate_proj_scales = [None] * ep_load_expert_num
         self.e_score_correction_bias = None
-        self.w2_list = [None] * self.ep_n_routed_experts
-        self.w2_scale_list = [None] * self.ep_n_routed_experts
+        self.w2_list = [None] * ep_load_expert_num
+        self.w2_scale_list = [None] * ep_load_expert_num
         self.scoring_func = network_config["scoring_func"]
         self.w1 = [None, None]  # weight, weight_scale
         self.w2 = [None, None]  # weight, weight_scale
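As a worked example of the new bookkeeping (the numbers below are illustrative, not taken from any particular model config): with 256 routed experts, 8 EP ranks, and 1 redundant expert per rank, each rank allocates 33 weight slots and the dispatch-side expert space grows from 256 to 264.

# Illustrative values; in the real code they come from the model config and env helpers.
n_routed_experts = 256
global_world_size = 8
redundancy_expert_num = 1  # redundant experts hosted on each rank

assert n_routed_experts % global_world_size == 0
ep_n_routed_experts = n_routed_experts // global_world_size       # 32 routed experts per rank
ep_load_expert_num = ep_n_routed_experts + redundancy_expert_num  # 33 weight slots per rank
total_expert_num_contain_redundancy = (
    n_routed_experts + redundancy_expert_num * global_world_size  # 264 experts visible to dispatch
)
print(ep_n_routed_experts, ep_load_expert_num, total_expert_num_contain_redundancy)  # 32 33 264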
@@ -84,6 +99,9 @@ def __init__(
         self.lock = threading.Lock()
         # init buffer
 
+        # auto update redundancy expert vars
+        self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert
+
     def experts(
         self,
         input_tensor,
@@ -106,6 +124,17 @@ def experts(
             num_expert_group=num_expert_group,
             scoring_func=self.scoring_func,
         )
+
+        if self.redundancy_expert_num > 0:
+            redundancy_topk_ids_repair(
+                topk_ids=topk_ids,
+                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
+                ep_expert_num=self.ep_n_routed_experts,
+                global_rank=self.global_rank_,
+                expert_counter=self.routed_expert_counter_tensor,
+                enable_counter=self.auto_update_redundancy_expert,
+            )
+
         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
         return fused_experts_impl(
@@ -114,7 +143,7 @@ def experts(
             w2=w2,
             topk_weights=topk_weights,
             topk_idx=topk_ids.to(torch.long),
-            num_experts=self.n_routed_experts,  # number of all experts
+            num_experts=self.total_expert_num_contain_redundancy,  # number of all experts, including redundant copies
             buffer=dist_group_manager.ep_buffer,
             is_prefill=is_prefill,
             use_fp8_w8a8=self.use_fp8_w8a8,
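The repair is needed because dispatch now addresses an enlarged expert space (total_expert_num_contain_redundancy): each rank exposes its ep_n_routed_experts routed experts followed by its redundancy_expert_num redundant slots, so routed expert ids must be shifted past the redundant slots of lower ranks. Below is a rough pure-PyTorch sketch of that id shift under the assumed layout; the actual policy for redirecting tokens onto redundant copies, and the expert-counter update, live inside the redundancy_topk_ids_repair Triton kernel and are not reproduced here.

import torch

def shift_routed_ids(topk_ids: torch.Tensor, ep_expert_num: int, redundancy_expert_num: int) -> torch.Tensor:
    # Hypothetical illustration only. Assumed physical layout per rank:
    # [ep_expert_num routed experts, then redundancy_expert_num redundant slots].
    owner_rank = topk_ids // ep_expert_num
    return topk_ids + owner_rank * redundancy_expert_num

topk_ids = torch.tensor([[0, 31, 32, 255]])  # logical routed expert ids (256-expert example)
print(shift_routed_ids(topk_ids, ep_expert_num=32, redundancy_expert_num=1))
# tensor([[  0,  31,  33, 262]]) -> 8 ranks, each holding 32 routed + 1 redundant expert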
@@ -142,13 +171,24 @@ def low_latency_dispatch(
             num_expert_group=self.n_group,
             scoring_func=self.scoring_func,
         )
+
+        if self.redundancy_expert_num > 0:
+            redundancy_topk_ids_repair(
+                topk_ids=topk_idx,
+                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
+                ep_expert_num=self.ep_n_routed_experts,
+                global_rank=self.global_rank_,
+                expert_counter=self.routed_expert_counter_tensor,
+                enable_counter=self.auto_update_redundancy_expert,
+            )
+
         topk_idx = topk_idx.to(torch.long)
         num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank()
         recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch(
             hidden_states,
             topk_idx,
             num_max_dispatch_tokens_per_rank,
-            self.n_routed_experts,
+            self.total_expert_num_contain_redundancy,
             use_fp8=self.use_fp8_w8a8,
             async_finish=False,
             return_recv_hook=True,
@@ -171,6 +211,15 @@ def select_experts_and_quant_input(
             num_expert_group=self.n_group,
             scoring_func=self.scoring_func,
         )
+        if self.redundancy_expert_num > 0:
+            redundancy_topk_ids_repair(
+                topk_ids=topk_idx,
+                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
+                ep_expert_num=self.ep_n_routed_experts,
+                global_rank=self.global_rank_,
+                expert_counter=self.routed_expert_counter_tensor,
+                enable_counter=self.auto_update_redundancy_expert,
+            )
         M, K = hidden_states.shape
         w1, w1_scale = self.w1
         block_size_k = 0
@@ -190,7 +239,6 @@ def dispatch(
         overlap_event: Optional[Any] = None,
     ):
         buffer = dist_group_manager.ep_buffer
-        num_experts = self.n_routed_experts
         # get_dispatch_layout
         (
             num_tokens_per_rank,
@@ -199,7 +247,11 @@ def dispatch(
             is_token_in_rank,
             previous_event,
         ) = buffer.get_dispatch_layout(
-            topk_idx, num_experts, previous_event=overlap_event, async_finish=True, allocate_on_comm_stream=True
+            topk_idx,
+            self.total_expert_num_contain_redundancy,
+            previous_event=overlap_event,
+            async_finish=True,
+            allocate_on_comm_stream=True,
         )
         recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
             qinput_tensor,
@@ -342,16 +394,18 @@ def _fuse(self):
             and None not in self.experts_gate_projs
             and None not in self.w2_list
         ):
-            w1_list = []
-            for i_experts in range(self.ep_n_routed_experts):
-                expert_gate_up_proj = torch.cat(
-                    [self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0
-                )
-                expert_gate_up_proj = expert_gate_up_proj
-                w1_list.append(expert_gate_up_proj)
-
-            inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1]
-            w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
+            gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape
+            up_out_dim, up_in_dim = self.experts_up_projs[0].shape
+            assert gate_in_dim == up_in_dim
+            dtype = self.experts_gate_projs[0].dtype
+            total_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
+
+            w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu")
+
+            for i_experts in range(self.ep_n_routed_experts + self.redundancy_expert_num):
+                w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts]
+                w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts]
+
             inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
             w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
             if not self.quantized_weight and self.quant_method is not None:
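The rewritten _fuse preallocates one (num_local_experts, gate_out + up_out, hidden) buffer and writes each expert's gate rows above its up rows, instead of building a Python list of per-expert concatenations and flattening it; redundant experts simply occupy the trailing slots. A small self-contained sketch of that layout with made-up shapes, showing it matches the old per-expert torch.cat result:

import torch

# Made-up sizes: 3 local experts (e.g. 2 routed + 1 redundant), 4 gate rows, 4 up rows, hidden size 8.
num_local_experts, gate_out, up_out, hidden = 3, 4, 4, 8
gate_projs = [torch.randn(gate_out, hidden) for _ in range(num_local_experts)]
up_projs = [torch.randn(up_out, hidden) for _ in range(num_local_experts)]

w1 = torch.empty((num_local_experts, gate_out + up_out, hidden))
for e in range(num_local_experts):
    w1[e, :gate_out, :] = gate_projs[e]  # gate rows on top
    w1[e, gate_out:, :] = up_projs[e]    # up rows below

assert torch.equal(w1[0], torch.cat([gate_projs[0], up_projs[0]], dim=0))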
@@ -372,17 +426,20 @@ def _fuse_weight_scale(self):
             and None not in self.experts_gate_proj_scales
             and None not in self.w2_scale_list
         ):
-            w1_scale_list = []
-            for i_experts in range(self.ep_n_routed_experts):
-                expert_gate_up_proj_scale = torch.cat(
-                    [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0
-                )
-                w1_scale_list.append(expert_gate_up_proj_scale)
-
-            inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1]
-            w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view(
-                len(w1_scale_list), inter_shape, hidden_size
+            gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape
+            up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape
+            assert gate_in_dim == up_in_dim
+            dtype = self.experts_gate_proj_scales[0].dtype
+            total_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
+
+            w1_scale = torch.empty(
+                (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu"
             )
+
+            for i_experts in range(self.ep_n_routed_experts + self.redundancy_expert_num):
+                w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts]
+                w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts]
+
             inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1]
             w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view(
                 len(self.w2_scale_list), inter_shape, hidden_size
@@ -411,7 +468,20 @@ def load_hf_weights(self, weights):
             if w2_weight in weights:
                 self.w2_list[i_experts_ep] = weights[w2_weight]
 
-        if self.quant_method is not None:
+        # Load weight parameters for redundant experts
+        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
+            i_experts = redundant_expert_id
+            w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
+            w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
+            w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
+            if w1_weight in weights:
+                self.experts_gate_projs[n_expert_ep + i] = weights[w1_weight]
+            if w3_weight in weights:
+                self.experts_up_projs[n_expert_ep + i] = weights[w3_weight]
+            if w2_weight in weights:
+                self.w2_list[n_expert_ep + i] = weights[w2_weight]
+
+        if self.quantized_weight:
             self._load_weight_scale(weights)
         self._fuse()
 
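On each rank the weight slots end up laid out as: indices 0 .. n_expert_ep - 1 hold the rank's own routed experts, and the i-th redundant expert is stored at n_expert_ep + i, looked up in the checkpoint by its original expert id. A hedged illustration of the key-to-slot mapping; the prefix and projection names below are placeholders, not the exact strings this model uses:

# Placeholder values for illustration only.
n_expert_ep = 32                  # routed experts owned by this rank
redundancy_expert_ids = [7, 101]  # redundant experts this rank also hosts

weight_prefix = "model.layers.3.mlp.experts"  # hypothetical prefix
w1_weight_name = "gate_proj"                  # hypothetical projection name
for i, expert_id in enumerate(redundancy_expert_ids):
    key = f"{weight_prefix}.{expert_id}.{w1_weight_name}.weight"
    slot = n_expert_ep + i  # redundant copies fill the trailing slots
    print(slot, "<-", key)
# 32 <- model.layers.3.mlp.experts.7.gate_proj.weight
# 33 <- model.layers.3.mlp.experts.101.gate_proj.weight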
@@ -430,6 +500,19 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None:
             if w2_scale in weights:
                 self.w2_scale_list[i_experts_ep] = weights[w2_scale]
 
+        # Load scale parameters for redundant experts
+        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
+            i_experts = redundant_expert_id
+            w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}"
+            w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}"
+            w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}"
+            if w1_scale in weights:
+                self.experts_gate_proj_scales[n_expert_ep + i] = weights[w1_scale]
+            if w3_scale in weights:
+                self.experts_up_proj_scales[n_expert_ep + i] = weights[w3_scale]
+            if w2_scale in weights:
+                self.w2_scale_list[n_expert_ep + i] = weights[w2_scale]
+
     def _cuda(self, cpu_tensor):
         device_id = get_current_device_id()
         if self.quantized_weight:
