diff --git a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
index 52e7abfc3bc..3550aa063bf 100644
--- a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
+++ b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
@@ -48,6 +48,11 @@
       __VA_ARGS__                                 \
       break;                                      \
     }                                             \
+    case 7: {                                     \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 7;  \
+      __VA_ARGS__                                 \
+      break;                                      \
+    }                                             \
     case 8: {                                     \
       constexpr size_t NUM_EXPERTS_PER_RANK = 8;  \
       __VA_ARGS__                                 \
@@ -68,6 +73,11 @@
       __VA_ARGS__                                 \
       break;                                      \
     }                                             \
+    case 17: {                                    \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 17; \
+      __VA_ARGS__                                 \
+      break;                                      \
+    }                                             \
     case 20: {                                    \
       constexpr size_t NUM_EXPERTS_PER_RANK = 20; \
       __VA_ARGS__                                 \
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 96af0dc6540..47788814770 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -607,6 +607,7 @@ async def rearrange_experts(self, request_dict: dict):
         Returns:
             tuple: response body, status code
         """
+        content, status_code = None, HTTPStatus.OK
         eplb_config = self.fd_config.eplb_config
         if not eplb_config.enable_eplb:
             content = {"code": 1, "msg": "redundant expert is disabled"}
@@ -709,6 +710,7 @@ async def get_per_expert_tokens_stats(self, request_dict: dict):
         Returns:
             tuple: response body, status code
         """
+        content, status_code = None, HTTPStatus.OK
         eplb_config = self.fd_config.eplb_config
         if not eplb_config.enable_eplb:
             content = {"code": 1, "msg": "redundant expert is disabled"}
diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py b/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
index b9c497734d1..bd8263b8610 100644
--- a/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
+++ b/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
@@ -277,7 +277,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                 moe_topk=self.top_k,
                 apply_norm_weight=True,  # apply_norm_weight
                 enable_softmax_top_k_fused=False,
-                redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
             )
         else:
             topk_idx, topk_weights = fastdeploy.model_executor.ops.xpu.moe_topk_select(
diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index a5a8cbec322..adf80d717e4 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -472,7 +472,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                     expert_id_to_ep_rank_array=expert_id_to_ep_rank_array,
                     expert_in_rank_num_list=expert_in_rank_num_list,
                     tokens_per_expert_stats_list=tokens_per_expert_stats_list,
-                    redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                    redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
                 )
             else:
                 topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
@@ -484,7 +484,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
                     moe_topk=self.top_k,
                     apply_norm_weight=True,
                     enable_softmax_top_k_fused=False,
-                    redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
+                    redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
                 )
         else:
             if layer.topk_method == "noaux_tc":
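The three redundant top-k call sites above now read the redundant-expert count from `eplb_config` rather than `model_config`. A minimal sketch of the new access pattern, using `SimpleNamespace` stand-ins for `FDConfig` (attribute names mirror the diff; the numeric values are illustrative only):

```python
from types import SimpleNamespace

# Stand-in for FDConfig: redundant_experts_num now lives on eplb_config,
# not on model_config. Values are illustrative, not from the PR.
fd_config = SimpleNamespace(
    model_config=SimpleNamespace(moe_num_experts=64),
    eplb_config=SimpleNamespace(enable_eplb=True, redundant_experts_num=2),
)

# Argument passed to the redundant top-k select ops in the hunks above.
redundant_ep_rank_num_plus_one = fd_config.eplb_config.redundant_experts_num + 1
print(redundant_ep_rank_num_plus_one)  # 3
```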
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
index 729295d9244..3974d81d47b 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -84,7 +84,7 @@ def init_ep(self, layer: nn.Layer) -> None:
                 "num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 "ep_size": layer.ep_size,
                 "ep_rank": layer.ep_rank,
-                "redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
+                "redundant_experts_num": layer.fd_config.eplb_config.redundant_experts_num,
                 "ep_group": layer.fd_config.parallel_config.ep_group,
             }

diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index c173b650416..7ba53d9f288 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -467,13 +467,18 @@ def load_experts_weight(
         """
         logical_expert_ids = [
             i
+            % (
+                self.fd_config.model_config.moe_num_experts[0]
+                if isinstance(self.fd_config.model_config.moe_num_experts, list)
+                else self.fd_config.model_config.moe_num_experts
+            )
             for i in range(
                 self.expert_id_offset,
                 self.expert_id_offset + self.num_local_experts,
             )
         ]
         ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
-        if self.redundant_table_manger is not None and is_rearrange is True:
+        if self.redundant_table_manger is not None:
             (
                 ep_rank_to_expert_id_list,
                 expert_id_to_ep_rank_array,
@@ -487,10 +492,7 @@ def load_experts_weight(
         down_proj_weights = []
         if isinstance(state_dict, list):
             state_dict = dict(state_dict)
-        is_ffn_merged = (
-            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
-            in state_dict
-        )
+        is_ffn_merged = up_gate_proj_expert_weight_key.format(logical_expert_ids[0]) in state_dict
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
@@ -498,7 +500,7 @@ def load_experts_weight(
                 up_gate_proj_weights.append(
                     get_tensor(
                         (
-                            state_dict.pop(up_gate_proj_expert_weight_key_name)
+                            state_dict[up_gate_proj_expert_weight_key_name]
                             if up_gate_proj_expert_weight_key_name in state_dict
                             else up_gate_proj_expert_weight_key_name
                         ),
@@ -508,7 +510,7 @@ def load_experts_weight(
                 down_proj_weights.append(
                     get_tensor(
                         (
-                            state_dict.pop(down_proj_expert_weight_key_name)
+                            state_dict[down_proj_expert_weight_key_name]
                             if down_proj_expert_weight_key_name in state_dict
                             else down_proj_expert_weight_key_name
                         ),
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index c243d88c112..7a0b676a8b0 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -252,8 +252,13 @@ def get_expert_ranges(fd_config):
        "mtp_block" if getattr(fd_config.speculative_config, "model_type", "main") == "mtp" else "layers"
    )

+    moe_num_experts = fd_config.model_config.moe_num_experts
+    if isinstance(moe_num_experts, list):
+        moe_num_experts = moe_num_experts[0]
    for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
        for j in get_expert_ranges(fd_config):
+            # Map redundant expert IDs back to actual expert IDs for weight loading
+            j = j % moe_num_experts
            up_gate_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.up_gate_proj.weight"
            down_proj_key = f"ernie.{prefix_layer_name}.{i}.mlp.experts.{j}.down_proj.weight"

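Both `load_experts_weight` and the weight-key loop in `load_weight_utils.py` above fold expert slot IDs back into the physical expert range with a modulo, so redundant replicas load the checkpoint weights of the expert they duplicate. A small sketch of that folding with illustrative sizes (64 routed experts, 8 redundant replicas, 8 EP ranks; not values taken from the PR):

```python
# Illustrative sizes, not taken from the PR.
moe_num_experts = 64
redundant_experts_num = 8
ep_size = 8

num_experts_per_rank = (moe_num_experts + redundant_experts_num) // ep_size  # 9
expert_id_offset = (ep_size - 1) * num_experts_per_rank  # last rank starts at slot 63

# Same folding as `i % moe_num_experts` in load_experts_weight() and
# `j = j % moe_num_experts` in load_weight_utils.py: slot IDs past the
# physical expert count wrap around, so keys such as
# "ernie.layers.{i}.mlp.experts.{id}.up_gate_proj.weight" always refer to
# an expert that exists in the checkpoint.
logical_expert_ids = [
    i % moe_num_experts
    for i in range(expert_id_offset, expert_id_offset + num_experts_per_rank)
]
print(logical_expert_ids)  # [63, 0, 1, 2, 3, 4, 5, 6, 7]
```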
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 8d072e791e0..3e540958fa5 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -400,7 +400,7 @@ def __init__(
            self.redundant_table_manger = RedundantExpertManger(
                n_routed_experts=fd_config.model_config.moe_num_experts,
                num_hidden_layers=fd_config.model_config.num_hidden_layers,
-                redundant_experts_num=fd_config.model_config.redundant_experts_num,
+                redundant_experts_num=fd_config.eplb_config.redundant_experts_num,
                ep_size=fd_config.parallel_config.expert_parallel_size,
            )

diff --git a/fastdeploy/worker/experts_manager.py b/fastdeploy/worker/experts_manager.py
index 4f6e4fe9255..8279c3b52ea 100644
--- a/fastdeploy/worker/experts_manager.py
+++ b/fastdeploy/worker/experts_manager.py
@@ -42,7 +42,7 @@ def __init__(
        self.num_hidden_layers = num_hidden_layers
        self.num_replicas = self.num_expert + self.redundant_experts_num

-        self.num_nodes = max(ep_size // 8, 1)
+        self.num_nodes = max(ep_size // 8, 8)
        self.num_gpus = ep_size
        self.num_groups = 1

diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 040e6c39d40..0ba85c147bf 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -927,6 +927,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
    parallel_config = ParallelConfig(vars(args))
    cache_config = CacheConfig(vars(args))
    scheduler_config = SchedulerConfig(vars(args))
+    eplb_config = EPLBConfig(args.eplb_config)

    parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
    parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
@@ -940,9 +941,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
    if parallel_config.expert_parallel_size > 1:
        expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
        if isinstance(model_config.moe_num_experts, list):
-            num_experts = model_config.moe_num_experts[0]
+            num_experts = model_config.moe_num_experts[0] + eplb_config.redundant_experts_num
        else:
-            num_experts = model_config.moe_num_experts
+            num_experts = model_config.moe_num_experts + eplb_config.redundant_experts_num
        num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
        num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
        parallel_config.expert_parallel_rank = expert_parallel_rank
@@ -958,7 +959,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
    plas_attention_config = PlasAttentionConfig(args.plas_attention_config)
    early_stop_config = EarlyStopConfig(args.early_stop_config)
-    eplb_config = EPLBConfig(args.eplb_config)

    structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
    routing_replay_config = RoutingReplayConfig(args.routing_replay_config)

diff --git a/tests/model_executor/test_ep.py b/tests/model_executor/test_ep.py
index e1229145147..373e8899396 100644
--- a/tests/model_executor/test_ep.py
+++ b/tests/model_executor/test_ep.py
@@ -460,7 +460,7 @@ def get_ep_rank_to_expert_id_list_by_layer(self, _layer_idx):
            top_k=2,
            routed_scaling_factor=1.0,
            gate_correction_bias=None,
-            fd_config=SimpleNamespace(model_config=SimpleNamespace(redundant_experts_num=0)),
+            fd_config=SimpleNamespace(eplb_config=SimpleNamespace(redundant_experts_num=0)),
        )

        gate_out = paddle.randn([1, 4], dtype="float32")
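For context on the `worker_process.py` hunk: the expert-parallel partitioning now counts redundant experts when sizing each rank, which is also why the dispatch macro gains `NUM_EXPERTS_PER_RANK` cases such as 7 and 17. A sketch of the arithmetic with hypothetical expert counts, chosen only to land on those two cases:

```python
def experts_per_rank(moe_num_experts: int, redundant_experts_num: int, ep_size: int) -> int:
    # Mirrors initialize_fd_config() above: redundant experts are added to the
    # total before the slots are divided evenly across EP ranks.
    num_experts = moe_num_experts + redundant_experts_num
    return num_experts // ep_size

# Hypothetical configurations (not from the PR) that hit the newly added
# NUM_EXPERTS_PER_RANK specializations in ep_moe_expert_dispatch.cu.
print(experts_per_rank(52, 4, 8))   # 7  -> new `case 7`
print(experts_per_rank(132, 4, 8))  # 17 -> new `case 17`
```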