From 20ef041d673d8cb3d86e0f062dcfef84db4a0636 Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001
Date: Wed, 7 Jan 2026 14:48:01 +0800
Subject: [PATCH 1/5] [Cherry-Pick] Support redundant expert for eplb

---
 .../gpu_ops/moe/ep_moe_expert_dispatch.cu     | 20 +++++++++++++++++++
 fastdeploy/entrypoints/engine_client.py       |  2 ++
 .../layers/backends/xpu/moe/ep.py             |  2 +-
 fastdeploy/model_executor/layers/moe/ep.py    |  4 ++--
 .../layers/moe/fused_moe_backend_base.py      |  2 +-
 fastdeploy/model_executor/layers/moe/moe.py   | 11 ++++------
 .../model_executor/load_weight_utils.py       |  1 +
 .../model_executor/models/ernie4_5_moe.py     |  2 +-
 fastdeploy/worker/experts_manager.py          |  2 +-
 fastdeploy/worker/worker_process.py           |  6 +++---
 10 files changed, 36 insertions(+), 16 deletions(-)

diff --git a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
index e4a1cc7f9cd..95cebf9c3f0 100644
--- a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
+++ b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
@@ -43,6 +43,11 @@
       __VA_ARGS__                                  \
       break;                                       \
     }                                              \
+    case 7: {                                      \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 7;   \
+      __VA_ARGS__                                  \
+      break;                                       \
+    }                                              \
     case 8: {                                      \
       constexpr size_t NUM_EXPERTS_PER_RANK = 8;   \
       __VA_ARGS__                                  \
       break;                                       \
     }                                              \
@@ -53,11 +58,26 @@
       __VA_ARGS__                                  \
       break;                                       \
     }                                              \
+    case 10: {                                     \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 10;  \
+      __VA_ARGS__                                  \
+      break;                                       \
+    }                                              \
     case 16: {                                     \
       constexpr size_t NUM_EXPERTS_PER_RANK = 16;  \
       __VA_ARGS__                                  \
       break;                                       \
     }                                              \
+    case 17: {                                     \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 17;  \
+      __VA_ARGS__                                  \
+      break;                                       \
+    }                                              \
+    case 20: {                                     \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 20;  \
+      __VA_ARGS__                                  \
+      break;                                       \
+    }                                              \
     case 32: {                                     \
       constexpr size_t NUM_EXPERTS_PER_RANK = 32;  \
       __VA_ARGS__                                  \
       break;                                       \
     }                                              \
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index c29e6d7f672..89505a8f97b 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -593,6 +593,7 @@ async def rearrange_experts(self, request_dict: dict):
         Returns:
             tuple: response body, status code
         """
+        content, status_code = None, HTTPStatus.OK
         eplb_config = self.config.eplb_config
         if not eplb_config.enable_eplb:
             content = {"code": 1, "msg": "redundant expert is disabled"}
@@ -695,6 +696,7 @@ async def get_per_expert_tokens_stats(self, request_dict: dict):
         Returns:
             tuple: response body, status code
         """
+        content, status_code = None, HTTPStatus.OK
        eplb_config = self.config.eplb_config
         if not eplb_config.enable_eplb:
             content = {"code": 1, "msg": "redundant expert is disabled"}
diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py b/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
index 71c2dd600ff..76783860ea0 100644
--- a/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
+++ b/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
@@ -262,7 +262,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 moe_topk=self.top_k,
                 apply_norm_weight=True,  # apply_norm_weight
                 enable_softmax_top_k_fused=False,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.xpu.moe_topk_select(
diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index 0d215c115f8..4dd2f690d6d 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -449,7 +449,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                 expert_in_rank_num_list=expert_in_rank_num_list,
                 tokens_per_expert_stats_list=tokens_per_expert_stats_list,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
@@ -461,7 +461,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 moe_topk=self.top_k,
                 apply_norm_weight=True,
                 enable_softmax_top_k_fused=False,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             if layer.topk_method == "noaux_tc":
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
index 03191066713..f15dd7d706b 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -84,7 +84,7 @@ def init_ep(self, layer: nn.Layer) -> None:
             "num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
             "ep_size": layer.ep_size,
             "ep_rank": layer.ep_rank,
-            "redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
+            "redundant_experts_num": layer.fd_config.eplb_config.redundant_experts_num,
             "ep_group": layer.fd_config.parallel_config.ep_group,
         }
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 4ad070aab6e..0dfa98134f8 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -463,7 +463,7 @@ def load_experts_weight(
             )
         ]
         ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
-        if self.redundant_table_manger is not None and is_rearrange is True:
+        if self.redundant_table_manger is not None:
             (
                 ep_rank_to_expert_id_list,
                 expert_id_to_ep_rank_array,
@@ -477,10 +477,7 @@ def load_experts_weight(
         up_gate_proj_weights = []
         down_proj_weights = []
         if isinstance(state_dict, list):
             state_dict = dict(state_dict)
-        is_ffn_merged = (
-            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
-            in state_dict
-        )
+        is_ffn_merged = up_gate_proj_expert_weight_key.format(logical_expert_ids[0]) in state_dict
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
@@ -488,7 +485,7 @@ def load_experts_weight(
                 up_gate_proj_weights.append(
                     get_tensor(
                         (
-                            state_dict.pop(up_gate_proj_expert_weight_key_name)
+                            state_dict[up_gate_proj_expert_weight_key_name]
                             if up_gate_proj_expert_weight_key_name in state_dict
                             else up_gate_proj_expert_weight_key_name
                         ),
@@ -498,7 +495,7 @@ def load_experts_weight(
                 down_proj_weights.append(
                     get_tensor(
                         (
-                            state_dict.pop(down_proj_expert_weight_key_name)
+                            state_dict[down_proj_expert_weight_key_name]
                             if down_proj_expert_weight_key_name in state_dict
                             else down_proj_expert_weight_key_name
                         ),
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 69c91384fa7..1a98dbf2ea6 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -253,6 +253,7 @@ def get_expert_ranges(fd_config):
     for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
         for j in get_expert_ranges(fd_config):
+            j = j % fd_config.model_config.moe_num_experts
             up_gate_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight"
             down_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight"
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 436b03395a9..af40e5dbcb0 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -372,7 +372,7 @@ def __init__(
             self.redundant_table_manger = RedundantExpertManger(
                 n_routed_experts=fd_config.model_config.moe_num_experts,
                 num_hidden_layers=fd_config.model_config.num_hidden_layers,
-                redundant_experts_num=fd_config.model_config.redundant_experts_num,
+                redundant_experts_num=fd_config.eplb_config.redundant_experts_num,
                 ep_size=fd_config.parallel_config.expert_parallel_size,
             )
diff --git a/fastdeploy/worker/experts_manager.py b/fastdeploy/worker/experts_manager.py
index 4f6e4fe9255..df4cc024d35 100644
--- a/fastdeploy/worker/experts_manager.py
+++ b/fastdeploy/worker/experts_manager.py
@@ -42,7 +42,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
         self.num_replicas = self.num_expert + self.redundant_experts_num
-        self.num_nodes = max(ep_size // 8, 1)
+        self.num_nodes = 8
         self.num_gpus = ep_size
         self.num_groups = 1
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 17cebe0fa53..60d82e601ca 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -892,6 +892,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     parallel_config = ParallelConfig(vars(args))
     cache_config = CacheConfig(vars(args))
     scheduler_config = SchedulerConfig(vars(args))
+    eplb_config = EPLBConfig(args.eplb_config)
     parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
     parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
@@ -899,9 +900,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     if parallel_config.expert_parallel_size > 1:
         expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
         if isinstance(model_config.moe_num_experts, list):
-            num_experts = model_config.moe_num_experts[0]
+            num_experts = model_config.moe_num_experts[0] + eplb_config.redundant_experts_num
         else:
-            num_experts = model_config.moe_num_experts
+            num_experts = model_config.moe_num_experts + eplb_config.redundant_experts_num
         num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
         num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
         max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
@@ -926,7 +927,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     plas_attention_config = PlasAttentionConfig(args.plas_attention_config)
     early_stop_config = EarlyStopConfig(args.early_stop_config)
-    eplb_config = EPLBConfig(args.eplb_config)
     structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))

From 3615487726ff786a00a69c95037178786689cfd6 Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001
Date: Thu, 8 Jan 2026 14:12:52 +0800
Subject: [PATCH 2/5] [Cherry-Pick] Support redundant expert for eplb

---
 fastdeploy/model_executor/layers/moe/moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 0dfa98134f8..6ef5c76dece 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -456,7 +456,7 @@ def load_experts_weight(
             down_proj_expert_weight_key (str): The key of down_proj expert weight.
         """
         logical_expert_ids = [
-            i
+            i % self.fd_config.model_config.moe_num_experts
             for i in range(
                 self.expert_id_offset,
                 self.expert_id_offset + self.num_local_experts,

From 277de52ec426d02764d80275f45b6afaa447cc37 Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001
Date: Thu, 8 Jan 2026 16:18:33 +0800
Subject: [PATCH 3/5] [Cherry-Pick] Support redundant expert for eplb

---
 fastdeploy/worker/experts_manager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastdeploy/worker/experts_manager.py b/fastdeploy/worker/experts_manager.py
index df4cc024d35..8279c3b52ea 100644
--- a/fastdeploy/worker/experts_manager.py
+++ b/fastdeploy/worker/experts_manager.py
@@ -42,7 +42,7 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
         self.num_replicas = self.num_expert + self.redundant_experts_num
-        self.num_nodes = 8
+        self.num_nodes = max(ep_size // 8, 8)
         self.num_gpus = ep_size
         self.num_groups = 1

From c07d52e774065b4992a64d374911c595859a486f Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001
Date: Thu, 8 Jan 2026 16:21:08 +0800
Subject: [PATCH 4/5] update

---
 fastdeploy/model_executor/load_weight_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 1a98dbf2ea6..25172c8858f 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -251,9 +251,12 @@ def get_expert_ranges(fd_config):
         "mtp_block" if getattr(fd_config.speculative_config, "model_type", "main") == "mtp" else "layers"
     )
+    moe_num_experts = fd_config.model_config.moe_num_experts
+    if isinstance(moe_num_experts, list):
+        moe_num_experts = moe_num_experts[0]
     for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
         for j in get_expert_ranges(fd_config):
-            j = j % fd_config.model_config.moe_num_experts
+            j = j % moe_num_experts
             up_gate_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight"
             down_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight"

From 56f26360b2dc5b5728ddb7b532296442f605a8fc Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001
Date: Thu, 8 Jan 2026 16:31:08 +0800
Subject: [PATCH 5/5] update

---
 fastdeploy/model_executor/layers/moe/moe.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 6ef5c76dece..19b87c76911 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -456,7 +456,12 @@ def load_experts_weight(
             down_proj_expert_weight_key (str): The key of down_proj expert weight.
         """
         logical_expert_ids = [
-            i % self.fd_config.model_config.moe_num_experts
+            i
+            % (
+                self.fd_config.model_config.moe_num_experts[0]
+                if isinstance(self.fd_config.model_config.moe_num_experts, list)
+                else self.fd_config.model_config.moe_num_experts
+            )
             for i in range(
                 self.expert_id_offset,
                 self.expert_id_offset + self.num_local_experts,