
Commit c2ad0a9
[Cherry-Pick][Feature] Support redundant expert for eplb (#5918) (#5923)

* [Cherry-Pick] Support redundant expert for eplb
* update

1 parent f065981 commit c2ad0a9

10 files changed: +44 -16 lines

10 files changed

+44
-16
lines changed

custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu

Lines changed: 20 additions & 0 deletions

@@ -43,6 +43,11 @@
     __VA_ARGS__                                  \
     break;                                       \
   }                                              \
+  case 7: {                                      \
+    constexpr size_t NUM_EXPERTS_PER_RANK = 7;   \
+    __VA_ARGS__                                  \
+    break;                                       \
+  }                                              \
   case 8: {                                      \
     constexpr size_t NUM_EXPERTS_PER_RANK = 8;   \
     __VA_ARGS__                                  \
@@ -53,11 +58,26 @@
     __VA_ARGS__                                  \
     break;                                       \
   }                                              \
+  case 10: {                                     \
+    constexpr size_t NUM_EXPERTS_PER_RANK = 10;  \
+    __VA_ARGS__                                  \
+    break;                                       \
+  }                                              \
   case 16: {                                     \
     constexpr size_t NUM_EXPERTS_PER_RANK = 16;  \
     __VA_ARGS__                                  \
     break;                                       \
   }                                              \
+  case 17: {                                     \
+    constexpr size_t NUM_EXPERTS_PER_RANK = 17;  \
+    __VA_ARGS__                                  \
+    break;                                       \
+  }                                              \
+  case 20: {                                     \
+    constexpr size_t NUM_EXPERTS_PER_RANK = 20;  \
+    __VA_ARGS__                                  \
+    break;                                       \
+  }                                              \
   case 32: {                                     \
     constexpr size_t NUM_EXPERTS_PER_RANK = 32;  \
     __VA_ARGS__                                  \
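The new switch cases cover per-rank expert counts that stop being powers of two once redundant experts join the pool. A rough sketch of the arithmetic, with hypothetical configurations (the formula mirrors the sizing change in fastdeploy/worker/worker_process.py further down in this commit):

# Python sketch -- expert counts below are hypothetical; the formula comes
# from initialize_fd_config: redundant copies are added to the expert pool
# before it is divided evenly across EP ranks.
def experts_per_rank(moe_num_experts: int, redundant_experts_num: int, ep_size: int) -> int:
    return (moe_num_experts + redundant_experts_num) // ep_size

assert experts_per_rank(64, 16, 8) == 10   # would hit the new `case 10`
assert experts_per_rank(128, 8, 8) == 17   # would hit the new `case 17`
assert experts_per_rank(128, 32, 8) == 20  # would hit the new `case 20`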

fastdeploy/entrypoints/engine_client.py

Lines changed: 2 additions & 0 deletions

@@ -593,6 +593,7 @@ async def rearrange_experts(self, request_dict: dict):
         Returns:
             tuple: response body, status code
         """
+        content, status_code = None, HTTPStatus.OK
         eplb_config = self.config.eplb_config
         if not eplb_config.enable_eplb:
             content = {"code": 1, "msg": "redundant expert is disabled"}
@@ -695,6 +696,7 @@ async def get_per_expert_tokens_stats(self, request_dict: dict):
         Returns:
             tuple: response body, status code
         """
+        content, status_code = None, HTTPStatus.OK
         eplb_config = self.config.eplb_config
         if not eplb_config.enable_eplb:
             content = {"code": 1, "msg": "redundant expert is disabled"}
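Both endpoints previously bound `content` and `status_code` only inside conditional branches, so a request path that skipped every branch would raise UnboundLocalError at the final return. Pre-initializing the pair gives every path a well-defined default. A minimal sketch of the pattern (the handler body is illustrative, not the real endpoint logic):

from http import HTTPStatus

def handler(enable_eplb: bool):
    # Defaults first: every path through the handler now returns something.
    content, status_code = None, HTTPStatus.OK
    if not enable_eplb:
        content = {"code": 1, "msg": "redundant expert is disabled"}
    # ... further branches may overwrite content/status_code ...
    return content, status_code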

fastdeploy/model_executor/layers/backends/xpu/moe/ep.py

Lines changed: 1 addition & 1 deletion

@@ -262,7 +262,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 moe_topk=self.top_k,
                 apply_norm_weight=True,  # apply_norm_weight
                 enable_softmax_top_k_fused=False,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.xpu.moe_topk_select(

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 2 additions & 2 deletions

@@ -449,7 +449,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                 expert_in_rank_num_list=expert_in_rank_num_list,
                 tokens_per_expert_stats_list=tokens_per_expert_stats_list,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
@@ -461,7 +461,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 moe_topk=self.top_k,
                 apply_norm_weight=True,
                 enable_softmax_top_k_fused=False,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             if layer.topk_method == "noaux_tc":

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ def init_ep(self, layer: nn.Layer) -> None:
             "num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
             "ep_size": layer.ep_size,
             "ep_rank": layer.ep_rank,
-            "redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
+            "redundant_experts_num": layer.fd_config.eplb_config.redundant_experts_num,
             "ep_group": layer.fd_config.parallel_config.ep_group,
         }
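This file and the two ep.py backends above make the same mechanical change: `redundant_experts_num` is now read from the dedicated `fd_config.eplb_config` instead of `fd_config.model_config`. A minimal sketch of the fields this commit relies on, inferred purely from their usage in the diff (the real EPLBConfig class in FastDeploy likely carries more options):

class EPLBConfig:
    """Sketch only: fields inferred from this commit's diff."""
    def __init__(self, enable_eplb: bool = False, redundant_experts_num: int = 0):
        self.enable_eplb = enable_eplb                      # gates the EPLB HTTP endpoints
        self.redundant_experts_num = redundant_experts_num  # extra expert replicas across ranks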

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 9 additions & 7 deletions

@@ -457,13 +457,18 @@ def load_experts_weight(
         """
         logical_expert_ids = [
             i
+            % (
+                self.fd_config.model_config.moe_num_experts[0]
+                if isinstance(self.fd_config.model_config.moe_num_experts, list)
+                else self.fd_config.model_config.moe_num_experts
+            )
             for i in range(
                 self.expert_id_offset,
                 self.expert_id_offset + self.num_local_experts,
             )
         ]
         ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
-        if self.redundant_table_manger is not None and is_rearrange is True:
+        if self.redundant_table_manger is not None:
             (
                 ep_rank_to_expert_id_list,
                 expert_id_to_ep_rank_array,
@@ -477,18 +482,15 @@ def load_experts_weight(
         down_proj_weights = []
         if isinstance(state_dict, list):
             state_dict = dict(state_dict)
-        is_ffn_merged = (
-            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
-            in state_dict
-        )
+        is_ffn_merged = up_gate_proj_expert_weight_key.format(logical_expert_ids[0]) in state_dict
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
                 up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
                 up_gate_proj_weights.append(
                     get_tensor(
                         (
-                            state_dict.pop(up_gate_proj_expert_weight_key_name)
+                            state_dict[up_gate_proj_expert_weight_key_name]
                             if up_gate_proj_expert_weight_key_name in state_dict
                             else up_gate_proj_expert_weight_key_name
                         ),
@@ -498,7 +500,7 @@ def load_experts_weight(
                 down_proj_weights.append(
                     get_tensor(
                         (
-                            state_dict.pop(down_proj_expert_weight_key_name)
+                            state_dict[down_proj_expert_weight_key_name]
                             if down_proj_expert_weight_key_name in state_dict
                             else down_proj_expert_weight_key_name
                         ),
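Two behavioral changes land here. First, logical expert ids now wrap modulo the real expert count, so the extra redundant slot indices map back onto actual experts; second, weights are read by indexing instead of `state_dict.pop(...)`, presumably so the same logical expert's tensor remains available when several physical slots replicate it. A small illustration of the wrapping, with hypothetical numbers:

# With 64 logical experts, a rank that owns global slots 70..79 loads
# the weights of logical experts 6..15.
moe_num_experts = 64
expert_id_offset, num_local_experts = 70, 10
logical_expert_ids = [
    i % moe_num_experts
    for i in range(expert_id_offset, expert_id_offset + num_local_experts)
]
assert logical_expert_ids == [6, 7, 8, 9, 10, 11, 12, 13, 14, 15]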

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 4 additions & 0 deletions

@@ -251,8 +251,12 @@ def get_expert_ranges(fd_config):
         "mtp_block" if getattr(fd_config.speculative_config, "model_type", "main") == "mtp" else "layers"
     )
 
+    moe_num_experts = fd_config.model_config.moe_num_experts
+    if isinstance(moe_num_experts, list):
+        moe_num_experts = moe_num_experts[0]
     for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
         for j in get_expert_ranges(fd_config):
+            j = j % moe_num_experts
             up_gate_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight"
             down_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight"
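The same wrapping is applied at the checkpoint-key level: slot indices beyond the logical expert count are folded back before the parameter names are built, so redundant slots resolve to tensors that actually exist in the checkpoint. Illustration with hypothetical numbers (the key format is taken from the diff above):

# A rank whose expert range is 60..67, with 64 logical experts, ends up
# requesting keys for experts 60, 61, 62, 63, 0, 1, 2, 3.
moe_num_experts = 64
keys = [
    f"ernie.layers.0.mlp.experts.{j % moe_num_experts}.up_gate_proj.weight"
    for j in range(60, 68)
]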

fastdeploy/model_executor/models/ernie4_5_moe.py

Lines changed: 1 addition & 1 deletion

@@ -372,7 +372,7 @@ def __init__(
         self.redundant_table_manger = RedundantExpertManger(
             n_routed_experts=fd_config.model_config.moe_num_experts,
             num_hidden_layers=fd_config.model_config.num_hidden_layers,
-            redundant_experts_num=fd_config.model_config.redundant_experts_num,
+            redundant_experts_num=fd_config.eplb_config.redundant_experts_num,
             ep_size=fd_config.parallel_config.expert_parallel_size,
         )

fastdeploy/worker/experts_manager.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
 
         self.num_replicas = self.num_expert + self.redundant_experts_num
-        self.num_nodes = max(ep_size // 8, 1)
+        self.num_nodes = max(ep_size // 8, 8)
         self.num_gpus = ep_size
         self.num_groups = 1

fastdeploy/worker/worker_process.py

Lines changed: 3 additions & 3 deletions

@@ -892,16 +892,17 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     parallel_config = ParallelConfig(vars(args))
     cache_config = CacheConfig(vars(args))
     scheduler_config = SchedulerConfig(vars(args))
+    eplb_config = EPLBConfig(args.eplb_config)
 
     parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
     parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
     # config for EP
     if parallel_config.expert_parallel_size > 1:
         expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
         if isinstance(model_config.moe_num_experts, list):
-            num_experts = model_config.moe_num_experts[0]
+            num_experts = model_config.moe_num_experts[0] + eplb_config.redundant_experts_num
         else:
-            num_experts = model_config.moe_num_experts
+            num_experts = model_config.moe_num_experts + eplb_config.redundant_experts_num
         num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
         num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
         max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
@@ -926,7 +927,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     plas_attention_config = PlasAttentionConfig(args.plas_attention_config)
 
     early_stop_config = EarlyStopConfig(args.early_stop_config)
-    eplb_config = EPLBConfig(args.eplb_config)
 
     structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
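Constructing `eplb_config` before the EP sizing block (instead of after it, where the old line sat) lets the redundant copies count toward each rank's expert budget and start offset. A worked example with hypothetical sizes:

# 64 experts plus 16 redundant replicas over 8 EP ranks:
num_experts = 64 + 16                    # 80
num_experts_per_rank = num_experts // 8  # 10 -> matches the new CUDA `case 10`
offsets = [rank * num_experts_per_rank for rank in range(8)]
assert offsets == [0, 10, 20, 30, 40, 50, 60, 70]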
