
Conversation

@xiaoxiaohehe001
Collaborator

Motivation

  • Support redundant experts for EPLB
  • When launching the service, add the following flag (see the parsing sketch after this list):
    --eplb-config '{"redundant_experts_num": 32, "redundant_expert_async_load_model_shmem_size_gb": 10}'
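
A minimal sketch of reading that JSON flag; the parsing and variable names below are illustrative only, not FastDeploy's actual option handling:

import json

# Illustrative only: parse the JSON string passed via --eplb-config.
raw = '{"redundant_experts_num": 32, "redundant_expert_async_load_model_shmem_size_gb": 10}'
eplb_config = json.loads(raw)

# Number of redundant expert replicas to provision for EPLB.
redundant_experts_num = eplb_config["redundant_experts_num"]
# Value of the shared-memory size key (in GB, per the key name).
shmem_size_gb = eplb_config["redundant_expert_async_load_model_shmem_size_gb"]

print(redundant_experts_num, shmem_size_gb)  # 32 10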

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag to the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but their meaning must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. If no unit tests are added, please explain the reason in this PR.
  • Provide accuracy results.
  • If this PR targets a release branch, make sure it has also been submitted to the develop branch first, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Jan 7, 2026

Thanks for your contribution!

@codecov-commenter

codecov-commenter commented Jan 7, 2026

Codecov Report

❌ Patch coverage is 16.66667% with 10 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@3ca99ab). Learn more about missing BASE report.

Files with missing lines                          Patch %   Lines
fastdeploy/model_executor/load_weight_utils.py      0.00%   4 Missing ⚠️
fastdeploy/entrypoints/engine_client.py             0.00%   2 Missing ⚠️
fastdeploy/worker/worker_process.py                33.33%   2 Missing ⚠️
fastdeploy/model_executor/layers/moe/moe.py        50.00%   0 Missing and 1 partial ⚠️
fastdeploy/worker/experts_manager.py                0.00%   1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5918   +/-   ##
==========================================
  Coverage           ?   67.02%           
==========================================
  Files              ?      348           
  Lines              ?    44673           
  Branches           ?     6876           
==========================================
  Hits               ?    29941           
  Misses             ?    12520           
  Partials           ?     2212           
Flag Coverage Δ
GPU 67.02% <16.66%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

Contributor

Copilot AI left a comment


Pull request overview

This PR adds redundant expert support for EPLB (Expert Parallel Load Balancing). The main changes migrate the redundant expert configuration from model_config to eplb_config and use the redundant expert count correctly in expert-parallel computation. It also fixes a few potential bugs, such as uninitialized variables.

Main changes:

  • Initialize eplb_config earlier so that the redundant expert count is available when computing the expert-parallel configuration
  • Unify the source of the redundant expert configuration, replacing model_config.redundant_experts_num with eplb_config.redundant_experts_num
  • Add support for 7 and 17 experts per rank in the CUDA kernel to cover redundant expert scenarios (a small sizing sketch follows this list)
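
As a rough illustration of how the redundant expert count enters the per-rank sizing (the numbers below are hypothetical and chosen only to produce 7 and 17 experts per rank; this is not the PR's code):

def experts_per_rank(moe_num_experts: int, redundant_experts_num: int, ep_size: int) -> int:
    # Total replicas = logical experts plus redundant copies, spread evenly over EP ranks.
    num_replicas = moe_num_experts + redundant_experts_num
    assert num_replicas % ep_size == 0, "replicas must divide evenly across ranks"
    return num_replicas // ep_size

print(experts_per_rank(moe_num_experts=48, redundant_experts_num=8, ep_size=8))   # 7
print(experts_per_rank(moe_num_experts=128, redundant_experts_num=8, ep_size=8))  # 17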

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 7 comments.

Show a summary per file
fastdeploy/worker/worker_process.py: initialize eplb_config earlier and include redundant_experts_num when computing num_experts
fastdeploy/worker/experts_manager.py: change num_nodes from a dynamic computation to a hardcoded value of 8
fastdeploy/model_executor/models/ernie4_5_moe.py: change the source of redundant_experts_num from model_config to eplb_config
fastdeploy/model_executor/load_weight_utils.py: add a modulo operation that maps redundant expert IDs back to actual expert IDs for weight loading
fastdeploy/model_executor/layers/moe/moe.py: remove the is_rearrange condition check and change how state_dict is accessed
fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py: change the source of redundant_experts_num from model_config to eplb_config
fastdeploy/model_executor/layers/moe/ep.py: change the source of redundant_experts_num from model_config to eplb_config
fastdeploy/model_executor/layers/backends/xpu/moe/ep.py: change the source of redundant_experts_num from model_config to eplb_config
fastdeploy/entrypoints/engine_client.py: initialize the content and status_code variables to avoid undefined-variable errors
custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu: add support for 7 and 17 experts per rank


for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
    for j in get_expert_ranges(fd_config):
        j = j % fd_config.model_config.moe_num_experts

Copilot AI Jan 7, 2026


When fd_config.model_config.moe_num_experts is a list, the modulo operation can fail. The code checks whether moe_num_experts is a list at line 237, but the modulo at line 252 uses it directly, which raises TypeError: unsupported operand type(s) for %: 'int' and 'list'.

Suggested fix:

moe_num_experts = fd_config.model_config.moe_num_experts
if isinstance(moe_num_experts, list):
    moe_num_experts = moe_num_experts[0]
j = j % moe_num_experts


self.num_replicas = self.num_expert + self.redundant_experts_num
self.num_nodes = max(ep_size // 8, 1)
self.num_nodes = 8

Copilot AI Jan 7, 2026


Hardcoding num_nodes to 8 instead of computing it as max(ep_size // 8, 1) may cause problems when ep_size is smaller than 8. The original logic guaranteed at least one node, whereas the hardcoded value of 8 may not suit every deployment configuration. Consider restoring the dynamic computation, adding validation that ep_size >= 8, or explaining in a code comment why hardcoding 8 is reasonable.

Suggested change
self.num_nodes = 8
# Determine number of nodes based on ep_size to support small ep_size values
self.num_nodes = max(ep_size // 8, 1)

@xiaoxiaohehe001 changed the title from [BugFix] Support redundant expert for eplb to [Feature] Support redundant expert for eplb on Jan 7, 2026
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

fastdeploy/entrypoints/engine_client.py:646

  • Initializing content, status_code = None, HTTPStatus.OK at the start of the function is an improvement that guarantees both variables always have defined values. Note, however, that in the logic at lines 644-646, when content is None and "ips" is not in the request an error content is set, but status_code may already have been set to BAD_REQUEST at line 643. As a result, the check at line 648 can see an inconsistent state. Consider always setting status_code together with content, or refactoring the logic to make this clearer.
        content, status_code = None, HTTPStatus.OK
        eplb_config = self.fd_config.eplb_config
        if not eplb_config.enable_eplb:
            content = {"code": 1, "msg": "redundant expert is disabled"}
            status_code = HTTPStatus.BAD_REQUEST
            return content, status_code

        if (
            request_dict.get("user", "") != eplb_config.redundant_expert_api_user
            or request_dict.get("passwd", "") != eplb_config.redundant_expert_api_password
        ):
            content = {"code": 1, "msg": "user or passwd is invalid"}
            status_code = HTTPStatus.UNAUTHORIZED
            return content, status_code

        if self.fd_config.parallel_config.tensor_parallel_rank != 0:
            content = {
                "code": 1,
                "msg": f"actual rank {self.fd_config.parallel_config.tensor_parallel_rank}, expect rank 0",
            }
            status_code = HTTPStatus.BAD_REQUEST
            return content, status_code

        action = request_dict.get("action", "")
        api_server_logger.info(f"redundant_expert: rearrange_experts recv request, action {action}")
        if action == "":
            # action: start rearrange experts
            # params: {'user': 'xxx', 'passwd': 'xxx', 'ips': ['10.54.99.77:8000', '10.54.99.77:8300']}
            if self.rearrange_experts_signal.value[0] != RearrangeExpertStatus.FREE.value:
                content = {
                    "code": 1,
                    "msg": f"rearrange is doing. actual status {self.rearrange_experts_signal.value[0]}, expect status {RearrangeExpertStatus.FREE.value}",
                }
                status_code = HTTPStatus.BAD_REQUEST
            if "ips" not in request_dict and content is None:
                content = {"code": 1, "msg": "ips in request is None"}
                status_code = HTTPStatus.BAD_REQUEST
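
A minimal sketch, with hypothetical helper and function names, of one way to keep content and status_code in sync as suggested above (not the code in this PR):

from http import HTTPStatus

def _error(msg, status):
    # Build the error payload and its HTTP status together so they cannot drift apart.
    return {"code": 1, "msg": msg}, status

def check_rearrange_request(request_dict, current_status, free_status):
    # Every early exit returns a (content, status_code) pair; the happy path returns (None, OK).
    if current_status != free_status:
        return _error(
            f"rearrange is doing. actual status {current_status}, expect status {free_status}",
            HTTPStatus.BAD_REQUEST,
        )
    if "ips" not in request_dict:
        return _error("ips in request is None", HTTPStatus.BAD_REQUEST)
    return None, HTTPStatus.OK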

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

"""
logical_expert_ids = [
i
i % self.fd_config.model_config.moe_num_experts

Copilot AI Jan 8, 2026


A modulo operation was added to map redundant expert IDs back to the actual expert IDs, which is a key logic change. Consider adding an English comment here explaining why the mapping is needed, e.g.: "Map redundant expert IDs back to actual expert IDs since redundant experts share weights with actual experts".
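
To make that suggestion concrete, a standalone sketch of the mapping (the function name and numbers are hypothetical):

def map_to_actual_expert_ids(expert_ids, moe_num_experts):
    # Map redundant expert IDs back to actual expert IDs, since redundant experts
    # share weights with the actual experts they duplicate.
    return [i % moe_num_experts for i in expert_ids]

# With 64 logical experts, redundant expert 65 reuses the weights of expert 1.
print(map_to_actual_expert_ids([65, 70], moe_num_experts=64))  # [1, 6]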

Comment on lines 253 to 256
moe_num_experts = fd_config.model_config.moe_num_experts
if isinstance(moe_num_experts, list):
    moe_num_experts = moe_num_experts[0]
j = j % moe_num_experts

Copilot AI Jan 8, 2026


Repeating the extraction and type check of moe_num_experts inside the loop may be inefficient. Consider moving the logic at lines 253 to 256 outside the loop so it runs once before iteration begins.
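
For illustration, hoisting that extraction above the loops could look like this (a sketch based on the diff context above, not the exact code):

# Extract and normalize moe_num_experts once, before iterating over layers and experts.
moe_num_experts = fd_config.model_config.moe_num_experts
if isinstance(moe_num_experts, list):
    moe_num_experts = moe_num_experts[0]

for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
    for j in get_expert_ranges(fd_config):
        j = j % moe_num_experts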

Contributor

Copilot AI left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

yuanlehome previously approved these changes on Jan 9, 2026
yuanlehome pushed a commit that referenced this pull request on Jan 9, 2026
* [Cherry-Pick] Support redundant expert for eplb

* [Cherry-Pick] Support redundant expert for eplb

* [Cherry-Pick] Support redundant expert for eplb

* update

* update
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.

up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
in state_dict
)
is_ffn_merged = up_gate_proj_expert_weight_key.format(logical_expert_ids[0]) in state_dict

Copilot AI Jan 9, 2026


The is_rearrange parameter check has been removed; the rearrangement logic now runs whenever redundant_table_manger is not None. This simplifies the logic, but make sure every call site of this method has been updated and no longer passes the is_rearrange parameter. In addition, the is_ffn_merged check generated at line 495 also no longer depends on is_rearrange, keeping the logic consistent. Please verify that this change does not affect weight loading in non-rearrangement scenarios.

Suggested change
is_ffn_merged = up_gate_proj_expert_weight_key.format(logical_expert_ids[0]) in state_dict
if logical_expert_ids:
    first_expert_key = up_gate_proj_expert_weight_key.format(logical_expert_ids[0])
    is_ffn_merged = first_expert_key in state_dict
else:
    # No local experts found, fall back to non-merged FFN loading path
    is_ffn_merged = False

Contributor

Copilot AI left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@Jiang-Jia-Jun merged commit 00a01ae into PaddlePaddle:develop on Jan 9, 2026
23 of 26 checks passed