
Commit 00a01ae

[Feature] Support redundant expert for eplb (#5918)
* [BugFix] support redundant expert for eplb
* support redundant expert for eplb
* support redundant expert for eplb
* update
* fix ci eplb
1 parent e6cdea4 commit 00a01ae

File tree

11 files changed, +36 -17 lines changed


custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu

Lines changed: 10 additions & 0 deletions
@@ -48,6 +48,11 @@
       __VA_ARGS__ \
       break; \
     } \
+    case 7: { \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 7; \
+      __VA_ARGS__ \
+      break; \
+    } \
     case 8: { \
       constexpr size_t NUM_EXPERTS_PER_RANK = 8; \
       __VA_ARGS__ \
@@ -68,6 +73,11 @@
       __VA_ARGS__ \
       break; \
     } \
+    case 17: { \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 17; \
+      __VA_ARGS__ \
+      break; \
+    } \
     case 20: { \
       constexpr size_t NUM_EXPERTS_PER_RANK = 20; \
       __VA_ARGS__ \
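The dispatch macro specializes the kernel on a compile-time NUM_EXPERTS_PER_RANK, so every per-rank expert count that can occur at runtime needs its own case. Redundant experts change that count: the replicas are added to the routed-expert total before it is divided across EP ranks (see the worker_process.py hunk further down), which makes values such as 7 and 17 reachable. A minimal sketch of that arithmetic, with hypothetical deployment sizes rather than numbers taken from this commit:

def experts_per_rank(moe_num_experts: int, redundant_experts_num: int, ep_size: int) -> int:
    # Mirrors the partition in initialize_fd_config: replicas are folded into
    # the total before the split across expert-parallel ranks.
    return (moe_num_experts + redundant_experts_num) // ep_size

print(experts_per_rank(64, 4, 4))   # 17 -> would hit the new `case 17` branch
print(experts_per_rank(48, 8, 8))   # 7  -> would hit the new `case 7` branch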

fastdeploy/entrypoints/engine_client.py

Lines changed: 2 additions & 0 deletions
@@ -607,6 +607,7 @@ async def rearrange_experts(self, request_dict: dict):
         Returns:
             tuple: response body, status code
         """
+        content, status_code = None, HTTPStatus.OK
         eplb_config = self.fd_config.eplb_config
         if not eplb_config.enable_eplb:
             content = {"code": 1, "msg": "redundant expert is disabled"}
@@ -709,6 +710,7 @@ async def get_per_expert_tokens_stats(self, request_dict: dict):
         Returns:
             tuple: response body, status code
         """
+        content, status_code = None, HTTPStatus.OK
         eplb_config = self.fd_config.eplb_config
         if not eplb_config.enable_eplb:
             content = {"code": 1, "msg": "redundant expert is disabled"}

fastdeploy/model_executor/layers/backends/xpu/moe/ep.py

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 moe_topk=self.top_k,
                 apply_norm_weight=True,  # apply_norm_weight
                 enable_softmax_top_k_fused=False,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.xpu.moe_topk_select(

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 2 additions & 2 deletions
@@ -472,7 +472,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                 expert_in_rank_num_list=expert_in_rank_num_list,
                 tokens_per_expert_stats_list=tokens_per_expert_stats_list,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
@@ -484,7 +484,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 moe_topk=self.top_k,
                 apply_norm_weight=True,
                 enable_softmax_top_k_fused=False,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             if layer.topk_method == "noaux_tc":
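The GPU selection ops also receive the redundant-expert routing tables (expert_id_to_ep_rank_array, expert_in_rank_num_list, tokens_per_expert_stats_list) together with redundant_ep_rank_num_plus_one. One plausible reading of that argument name: each logical expert can be served by its home placement plus up to redundant_experts_num replicas, so the per-expert mapping has at most redundant_experts_num + 1 entries. A conceptual sketch of that idea only, not the library's actual data layout:

# Conceptual sketch, not FastDeploy's real structures: each logical expert maps
# to its home slot plus up to `redundant_experts_num` replica slots.
redundant_experts_num = 1
expert_to_physical_slots = {
    0: [0, 9],   # expert 0: home slot 0, replica in slot 9
    1: [1],      # expert 1: no replica
    2: [2, 10],
    3: [3],
}
assert all(
    len(slots) <= redundant_experts_num + 1
    for slots in expert_to_physical_slots.values()
)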

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def init_ep(self, layer: nn.Layer) -> None:
             "num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
             "ep_size": layer.ep_size,
             "ep_rank": layer.ep_rank,
-            "redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
+            "redundant_experts_num": layer.fd_config.eplb_config.redundant_experts_num,
             "ep_group": layer.fd_config.parallel_config.ep_group,
         }
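Together with the identical one-line changes in ep.py, xpu/moe/ep.py and ernie4_5_moe.py, this moves the source of truth for redundant_experts_num from model_config into the dedicated eplb_config. A rough sketch of the config shape implied by the code on this page (enable_eplb gates the HTTP endpoints, redundant_experts_num sizes the replica pool); the real EPLBConfig class likely carries more fields and its own parsing:

from dataclasses import dataclass

@dataclass
class EPLBConfigSketch:
    # Field names taken from this diff; defaults and everything else are assumed.
    enable_eplb: bool = False
    redundant_experts_num: int = 0

cfg = EPLBConfigSketch(enable_eplb=True, redundant_experts_num=4)
print(cfg.redundant_experts_num + 1)  # value passed as redundant_ep_rank_num_plus_one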

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 9 additions & 7 deletions
@@ -467,13 +467,18 @@ def load_experts_weight(
         """
         logical_expert_ids = [
             i
+            % (
+                self.fd_config.model_config.moe_num_experts[0]
+                if isinstance(self.fd_config.model_config.moe_num_experts, list)
+                else self.fd_config.model_config.moe_num_experts
+            )
             for i in range(
                 self.expert_id_offset,
                 self.expert_id_offset + self.num_local_experts,
             )
         ]
         ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
-        if self.redundant_table_manger is not None and is_rearrange is True:
+        if self.redundant_table_manger is not None:
             (
                 ep_rank_to_expert_id_list,
                 expert_id_to_ep_rank_array,
@@ -487,18 +492,15 @@ def load_experts_weight(
         down_proj_weights = []
         if isinstance(state_dict, list):
             state_dict = dict(state_dict)
-        is_ffn_merged = (
-            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
-            in state_dict
-        )
+        is_ffn_merged = up_gate_proj_expert_weight_key.format(logical_expert_ids[0]) in state_dict
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
                 up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
                 up_gate_proj_weights.append(
                     get_tensor(
                         (
-                            state_dict.pop(up_gate_proj_expert_weight_key_name)
+                            state_dict[up_gate_proj_expert_weight_key_name]
                             if up_gate_proj_expert_weight_key_name in state_dict
                             else up_gate_proj_expert_weight_key_name
                         ),
@@ -508,7 +510,7 @@ def load_experts_weight(
                 down_proj_weights.append(
                     get_tensor(
                         (
-                            state_dict.pop(down_proj_expert_weight_key_name)
+                            state_dict[down_proj_expert_weight_key_name]
                             if down_proj_expert_weight_key_name in state_dict
                             else down_proj_expert_weight_key_name
                         ),
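Two replica-aware adjustments to weight loading: logical_expert_ids now wraps the per-rank slot range back into [0, moe_num_experts) with a modulo, and tensors are read from state_dict by indexing instead of pop, presumably because the same key may need to be read more than once (a logical expert that also backs a redundant slot, or a reload during rearrangement). A toy illustration of both points, with invented sizes and key names:

# Toy numbers: 8 routed experts, this rank owns slots 8..9, i.e. two replica
# slots appended after the routed experts.
moe_num_experts = 8
expert_id_offset, num_local_experts = 8, 2

logical_expert_ids = [
    i % moe_num_experts
    for i in range(expert_id_offset, expert_id_offset + num_local_experts)
]
print(logical_expert_ids)  # [0, 1] -- the replica slots reuse real experts' weights

state_dict = {f"experts.{e}.up_gate_proj.weight": f"tensor-{e}" for e in range(moe_num_experts)}
# Indexing keeps the entry alive, so fetching the same logical expert again
# (for another slot or a later rearrange) still succeeds; pop() would have
# removed it after the first load.
for e in logical_expert_ids + logical_expert_ids:
    _ = state_dict[f"experts.{e}.up_gate_proj.weight"]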

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 5 additions & 0 deletions
@@ -252,8 +252,13 @@ def get_expert_ranges(fd_config):
         "mtp_block" if getattr(fd_config.speculative_config, "model_type", "main") == "mtp" else "layers"
     )

+    moe_num_experts = fd_config.model_config.moe_num_experts
+    if isinstance(moe_num_experts, list):
+        moe_num_experts = moe_num_experts[0]
     for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
         for j in get_expert_ranges(fd_config):
+            # Map redundant expert IDs back to actual expert IDs for weight loading
+            j = j % moe_num_experts
             up_gate_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight"
             down_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight"
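The same wrap-around applied where the checkpoint key strings are built: get_expert_ranges can now yield slot indices past the real expert count (the redundant replicas), and those slots must resolve to an existing expert's tensors. For example, with 64 routed experts a replica slot index of 66 resolves to expert 2's keys (illustrative numbers):

moe_num_experts = 64
prefix_layer_name, i, slot = "layers", 3, 66   # slot 66 would be a redundant replica
j = slot % moe_num_experts                     # -> 2
print(f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight")
# ernie.layers.3.mlp.experts.2.up_gate_proj.weight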

fastdeploy/model_executor/models/ernie4_5_moe.py

Lines changed: 1 addition & 1 deletion
@@ -400,7 +400,7 @@ def __init__(
             self.redundant_table_manger = RedundantExpertManger(
                 n_routed_experts=fd_config.model_config.moe_num_experts,
                 num_hidden_layers=fd_config.model_config.num_hidden_layers,
-                redundant_experts_num=fd_config.model_config.redundant_experts_num,
+                redundant_experts_num=fd_config.eplb_config.redundant_experts_num,
                 ep_size=fd_config.parallel_config.expert_parallel_size,
             )

fastdeploy/worker/experts_manager.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers

         self.num_replicas = self.num_expert + self.redundant_experts_num
-        self.num_nodes = max(ep_size // 8, 1)
+        self.num_nodes = max(ep_size // 8, 8)
         self.num_gpus = ep_size
         self.num_groups = 1
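RedundantExpertManger sizes its placement problem from these fields: num_replicas counts every physical expert slot (routed plus redundant), and num_nodes is derived from the EP world size, apparently assuming 8 ranks per node, with the floor raised from 1 to 8 here. Evaluating the expressions for two EP sizes shows where the new floor actually bites (inputs are hypothetical):

def derived(num_expert: int, redundant_experts_num: int, ep_size: int) -> dict:
    return {
        "num_replicas": num_expert + redundant_experts_num,
        "num_nodes": max(ep_size // 8, 8),   # was max(ep_size // 8, 1)
        "num_gpus": ep_size,
    }

print(derived(64, 4, ep_size=8))    # num_nodes: 8 (previously 1)
print(derived(64, 4, ep_size=128))  # num_nodes: 16, unchanged by the new floor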

fastdeploy/worker/worker_process.py

Lines changed: 3 additions & 3 deletions
@@ -927,6 +927,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     parallel_config = ParallelConfig(vars(args))
     cache_config = CacheConfig(vars(args))
     scheduler_config = SchedulerConfig(vars(args))
+    eplb_config = EPLBConfig(args.eplb_config)

     parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
     parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
@@ -940,9 +941,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     if parallel_config.expert_parallel_size > 1:
         expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
         if isinstance(model_config.moe_num_experts, list):
-            num_experts = model_config.moe_num_experts[0]
+            num_experts = model_config.moe_num_experts[0] + eplb_config.redundant_experts_num
         else:
-            num_experts = model_config.moe_num_experts
+            num_experts = model_config.moe_num_experts + eplb_config.redundant_experts_num
         num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
         num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
         parallel_config.expert_parallel_rank = expert_parallel_rank
@@ -958,7 +959,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     plas_attention_config = PlasAttentionConfig(args.plas_attention_config)

     early_stop_config = EarlyStopConfig(args.early_stop_config)
-    eplb_config = EPLBConfig(args.eplb_config)

     structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
     routing_replay_config = RoutingReplayConfig(args.routing_replay_config)
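EPLBConfig construction moves ahead of the expert-parallel partitioning because the partition now depends on redundant_experts_num: replicas are added to the expert total before it is split into per-rank ranges, shifting every rank's slot count and start offset. A runnable sketch of the resulting partition with stand-in sizes (the real values come from ModelConfig, EPLBConfig and ParallelConfig):

moe_num_experts, redundant_experts_num, expert_parallel_size = 64, 4, 4

num_experts = moe_num_experts + redundant_experts_num        # replicas included up front
num_experts_per_rank = num_experts // expert_parallel_size   # 17 -> matches the new kernel case
for expert_parallel_rank in range(expert_parallel_size):
    num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
    print(expert_parallel_rank, num_experts_start_offset, num_experts_per_rank)
# Rank 3 ends at slot 67; slots 64..67 are the redundant replicas that the
# modulo in load_experts_weight maps back to real experts.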
