diff --git a/paddleformers/nn/moe_deepep/__init__.py b/paddleformers/nn/moe_deepep/__init__.py
new file mode 100644
index 0000000000..47305c2bdd
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from typing import TYPE_CHECKING
+
+from ...utils.lazy_import import _LazyModule
+
+import_structure = {
+    "modular_moe_layer": ["ModularMoELayer"],
+    "moe_communication": ["MoECommunicationInterface", "StandardMoECommunication", "DeepEPMoECommunication"],
+    "moe_expert": ["MoEExpertInterface", "StandardMoEExpert", "Qwen2MoeMLP"],
+    "moe_gate": ["PretrainedMoEGate", "FlexibleMoEGate"],
+    "moe_factory": ["QuickAccessMoEFactory"],
+}
+
+if TYPE_CHECKING:
+    from .modular_moe_layer import *
+    from .moe_communication import *
+    from .moe_expert import *
+    from .moe_factory import *
+    from .moe_gate import *
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )
diff --git a/paddleformers/nn/moe_deepep/modular_moe_layer.py b/paddleformers/nn/moe_deepep/modular_moe_layer.py
new file mode 100644
index 0000000000..6ed4e853cd
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/modular_moe_layer.py
@@ -0,0 +1,470 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any, Dict, Optional, Tuple
+
+import paddle
+import paddle.distributed as dist
+from paddle import nn
+
+from .moe_communication import (
+    DeepEPMoECommunication,
+    MoECommunicationInterface,
+    StandardMoECommunication,
+)
+from .moe_expert import MoEExpertInterface, Qwen2MoeMLP, StandardMoEExpert
+from .moe_gate import FlexibleMoEGate, PretrainedMoEGate
+from .moe_loss import LossConfig, LossFunction, LossType
+from .moe_loss_instance import get_global_loss_registry
+
+logger = logging.getLogger(__name__)
+global_loss_registry = get_global_loss_registry()
+
+
+class ModularMoELayer(nn.Layer):
+    """
+    Modular MoE layer with expert-parallel (EP) execution.
+
+    Design goals:
+    1. Highly modular: gating, experts, and communication are fully decoupled.
+    2. Easy to extend: custom gating strategies and expert architectures are supported.
+    """
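+    # Usage sketch (illustrative only; the sizes are examples and the config
+    # keys follow the moe_config dictionary documented in __init__ below):
+    #
+    #   layer = ModularMoELayer(
+    #       hidden_size=2048,
+    #       moe_intermediate_size=1408,
+    #       num_experts=64,
+    #       num_shared_experts=0,
+    #       num_experts_per_tok=8,
+    #       norm_topk_prob=True,
+    #       expert_activation="silu",
+    #       moe_config={"expert_parallel_degree": 8, "aux_loss_weight": 0.01},
+    #   )
+    #   output, _ = layer(paddle.randn([4, 128, 2048]))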
+
+    def __init__(
+        self,
+        hidden_size: int,
+        moe_intermediate_size: int,
+        num_experts: int,
+        num_shared_experts: int,
+        num_experts_per_tok: int,
+        norm_topk_prob: bool,
+        expert_activation: str,
+        moe_config: Dict,
+    ):
+        """
+        Initialize the modular MoE layer.
+
+        Args:
+            hidden_size: hidden dimension
+            moe_intermediate_size: MoE intermediate (FFN) dimension
+            num_experts: total number of experts
+            num_experts_per_tok: number of experts selected per token (top-k)
+            num_shared_experts: number of shared experts
+            norm_topk_prob: whether to normalize the top-k probabilities
+            expert_activation: activation function used by the experts
+            moe_config: additional MoE-related configuration
+
+        Keys recognized in moe_config:
+            moe_group: MoE communication group
+            custom_gate: custom gating network
+            custom_expert: custom expert network
+            custom_communication: custom communication strategy
+            expert_parallel_degree: EP parallel degree
+            gate_activation: gating activation function
+            aux_loss_weight: auxiliary-loss weight (legacy mode)
+            z_loss_weight: z-loss weight (legacy mode)
+            train_topk_method: top-k method used during training
+            inference_topk_method: top-k method used during inference
+            drop_tokens: whether to drop tokens once an expert is full
+            use_flexible_loss: whether to use the flexible loss system
+            loss_configs: list of loss configurations (flexible mode)
+            loss_combiner_name: name of the loss combiner
+            expert_dropout: expert dropout rate
+        """
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_experts = num_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_shared_experts = num_shared_experts
+        self.moe_intermediate_size = moe_intermediate_size
+        self.expert_activation = expert_activation
+        self.norm_topk_prob = norm_topk_prob
+
+        self.moe_group = moe_config.get("moe_group", "data")
+        self.custom_gate = moe_config.get("custom_gate", None)
+        self.custom_expert = moe_config.get("custom_expert", None)
+        self.custom_communication = moe_config.get("custom_communication", None)
+        self.expert_parallel_degree = moe_config.get("expert_parallel_degree", 1)
+        self.gate_activation = moe_config.get("gate_activation", "softmax")
+        self.aux_loss_weight = moe_config.get("aux_loss_weight", 0.01)
+        self.z_loss_weight = moe_config.get("z_loss_weight", 0.0)
+        # Resolved at construction time; layers are built in training mode, so
+        # train_topk_method is selected by default.
+        self.topk_method = (
+            moe_config.get("train_topk_method", "greedy")
+            if self.training
+            else moe_config.get("inference_topk_method", "greedy")
+        )
+        self.drop_tokens = moe_config.get("drop_tokens", True)
+        self.use_flexible_loss = moe_config.get("use_flexible_loss", False)
+        self.expert_dropout = moe_config.get("expert_dropout", 0.0)
+        self.loss_configs = moe_config.get("loss_configs", None)
+        self.loss_combiner_name = moe_config.get("loss_combiner_name", "weighted_sum")
+
+        # Initialize EP-parallel parameters
+        self._init_expert_parallel()
+        # Build the gating network
+        if self.custom_gate is not None:
+            self.gate = self.custom_gate
+        elif self.use_flexible_loss:
+            # Use the flexible loss system
+            if self.loss_configs is None:
+                self.loss_configs = [
+                    LossConfig("auxiliary", LossType.AUXILIARY, weight=self.aux_loss_weight),
+                    LossConfig("z_loss", LossType.Z_LOSS, weight=self.z_loss_weight),
+                ]
+            self.gate = FlexibleMoEGate(
+                num_experts=self.num_experts,
+                expert_hidden_size=self.hidden_size,
+                drop_tokens=self.drop_tokens,
+                topk_method=self.topk_method,
+                num_experts_per_tok=self.num_experts_per_tok,
+                norm_topk_prob=self.norm_topk_prob,
+                moe_config=moe_config,
+                loss_registry=global_loss_registry,
+                loss_configs=self.loss_configs,
+                loss_combiner_name=self.loss_combiner_name,
+            )
+        else:
+            self.gate = PretrainedMoEGate(
+                num_experts=self.num_experts,
+                expert_hidden_size=self.hidden_size,
+                drop_tokens=self.drop_tokens,
+                topk_method=self.topk_method,
+                num_experts_per_tok=self.num_experts_per_tok,
+                norm_topk_prob=self.norm_topk_prob,
+                moe_config=moe_config,
+            )
+
+        # Build the expert networks
+        if self.custom_expert is not None:
+            # If an instance was passed in, reuse its class
+            if isinstance(self.custom_expert, MoEExpertInterface):
+                expert_class = type(self.custom_expert)
+            else:
+                expert_class = self.custom_expert
+        else:
+            expert_class = Qwen2MoeMLP
+            # StandardMoEExpert is available as a simpler baseline
+
+        self.experts = nn.LayerList([])
+        for i in range(self.num_experts):
+            if i // self.num_experts_per_device == self.moe_rank:
+                expert = expert_class(
+                    hidden_size=self.hidden_size,
+                    intermediate_size=self.moe_intermediate_size,
+                    expert_activation=self.expert_activation,
+                    expert_dropout=self.expert_dropout,
+                    config={},
+                )
+                self.experts.append(expert)
+            else:
+                # Placeholder layer for experts hosted on other ranks
+                empty_expert = nn.Layer()
+                self.experts.append(empty_expert)
+
+        # Build the shared experts
+        if self.num_shared_experts > 0:
+            self.shared_experts = expert_class(
+                hidden_size=self.hidden_size,
+                intermediate_size=self.moe_intermediate_size * self.num_shared_experts,
+                expert_activation=self.expert_activation,
+                expert_dropout=self.expert_dropout,
+                config={},
+            )
+        else:
+            self.shared_experts = None
+
+        # Build the communication strategy
+        if self.custom_communication is not None:
+            self.communication = self.custom_communication
+        else:
+            # Compare explicitly: a bare os.getenv(...) with a non-empty
+            # default string is always truthy.
+            if os.getenv("USE_DEEPEP", "0") == "1":
+                self.communication = DeepEPMoECommunication()
+            else:
+                self.communication = StandardMoECommunication()
+
+    def _init_expert_parallel(self):
+        """
+        Initialize expert-parallel parameters.
+        """
+
+        def _parse_moe_expert_parallel(num_experts: int, expert_parallel_degree: int) -> int:
+            """
+            Derive the per-device expert count for expert parallelism.
+
+            Args:
+                num_experts: total number of experts
+                expert_parallel_degree: expert-parallel degree
+
+            Returns:
+                moe_num_experts_per_device: number of experts per device
+            """
+            assert (
+                num_experts >= expert_parallel_degree
+            ), f"expected num_experts={num_experts} >= moe_world_size={expert_parallel_degree}"
+            assert (
+                num_experts % expert_parallel_degree == 0
+            ), f"expected num_experts={num_experts} % moe_world_size={expert_parallel_degree} == 0"
+
+            moe_num_experts_per_device = num_experts // expert_parallel_degree
+            return moe_num_experts_per_device
+
+        try:
+            dist.fleet.get_hybrid_communicate_group()
+            is_fleet_init = True
+        except AttributeError:
+            is_fleet_init = False
+
+        if is_fleet_init and dist.fleet.get_hybrid_communicate_group().get_data_parallel_world_size() > 1:
+            if self.moe_group == "data":
+                self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+            elif self.moe_group == "expert":
+                self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group
+            self.moe_rank = dist.get_rank(self.moe_group)
+            self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank
+            new_expert_parallel_degree = dist.get_world_size(self.moe_group)
+            assert (
+                self.expert_parallel_degree == new_expert_parallel_degree
+            ), f"self.expert_parallel_degree={self.expert_parallel_degree} != moe_world_size={new_expert_parallel_degree}"
+            self.expert_parallel_degree = 1 if self.expert_parallel_degree < 0 else self.expert_parallel_degree
+            self.num_experts_per_device = _parse_moe_expert_parallel(self.num_experts, self.expert_parallel_degree)
+        else:
+            self.moe_group = None
+            self.moe_rank = 0
+            self.expert_parallel_degree = 1
+            self.num_experts_per_device = self.num_experts
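+
+    # Partitioning sketch (illustrative numbers): with num_experts=64 and
+    # expert_parallel_degree=8, num_experts_per_device is 64 // 8 = 8, and
+    # rank r hosts the global expert ids [8 * r, 8 * r + 8), exactly the ids
+    # that satisfy `i // self.num_experts_per_device == self.moe_rank` in
+    # __init__ above.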
+
+    def forward(self, hidden_states: paddle.Tensor) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
+        """
+        MoE layer forward pass.
+
+        Args:
+            hidden_states: input hidden states, shape [batch_size, seq_len, hidden_size]
+
+        Returns:
+            output: output hidden states, shape [batch_size, seq_len, hidden_size]
+            None: placeholder; gate losses are tracked on the gate itself
+        """
+        batch_size, seq_len, d_model = hidden_states.shape
+        residuals = hidden_states
+
+        # Gating forward pass (topkgating)
+        capacity, topk_weights, topk_indices, priorities, aux_loss, z_loss = self.gate(hidden_states)
+
+        # Flatten the input to [batch_size * seq_len, hidden_size]
+        reshaped_input = hidden_states.reshape([-1, d_model])
+
+        # MoE forward pass
+        if self.expert_parallel_degree > 1:
+            # EP-parallel path
+            output = self._forward_with_ep_parallel(reshaped_input, topk_indices, topk_weights)
+        else:
+            # Traditional (single-device) MoE path
+            output = self._forward_traditional_moe(reshaped_input, topk_indices, topk_weights)
+
+        # Restore the original shape
+        output = output.reshape([batch_size, seq_len, d_model])
+
+        # Add the shared-expert output
+        if self.shared_experts is not None:
+            shared_output = self.shared_experts(residuals)
+            output = output + shared_output
+
+        # aux_loss and z_loss are currently not returned; use the loss getters below
+        return output, None
+
+    def _forward_traditional_moe(
+        self, hidden_states: paddle.Tensor, selected_experts: paddle.Tensor, topk_weights: paddle.Tensor
+    ) -> paddle.Tensor:
+        """
+        Traditional MoE forward pass.
+
+        Args:
+            hidden_states: input hidden states, shape [batch_size*seq_len, hidden_size]
+            selected_experts: top-k expert indices, shape [batch_size*seq_len, num_experts_per_tok]
+            topk_weights: top-k weights, shape [batch_size*seq_len, num_experts_per_tok]
+
+        Returns:
+            output: output hidden states, shape [batch_size*seq_len, hidden_size]
+        """
+        _, d_model = hidden_states.shape
+        final_hidden_states = paddle.zeros_like(hidden_states, dtype=hidden_states.dtype)
+
+        # One-hot encode the selected experts to build an expert mask; this is
+        # used to index which tokens each expert should process.
+        expert_mask = paddle.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0])
+        # [num_experts, topk, bs*seq]
+        tokens_per_expert = expert_mask.reshape([expert_mask.shape[0], -1]).sum(axis=-1)
+        # Loop over all experts and run each one on its assigned tokens
+        for expert_idx in range(self.num_experts):
+            if tokens_per_expert[expert_idx] == 0:
+                continue
+            expert_layer = self.experts[expert_idx]
+            # top_x indexes the top-k slot, idx indexes the token position
+            top_x, idx = paddle.where(expert_mask[expert_idx])
+            # Gather the hidden states for the current expert and scale its
+            # output by the routing weights of the corresponding tokens.
+            current_state = hidden_states[idx, None].reshape([-1, d_model])
+            current_hidden_states = expert_layer(current_state) * topk_weights[idx, top_x].unsqueeze(-1)
+            final_hidden_states.index_add_(
+                index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
+            )
+
+        return final_hidden_states.cast(hidden_states.dtype)
+
+    def _forward_with_ep_parallel(
+        self, hidden_states: paddle.Tensor, topk_indices: paddle.Tensor, topk_weights: paddle.Tensor
+    ) -> paddle.Tensor:
+        """
+        EP-parallel MoE forward pass.
+
+        Args:
+            hidden_states: input hidden states, shape [batch_size*seq_len, hidden_size]
+            topk_indices: top-k expert indices, shape [batch_size*seq_len, num_experts_per_tok]
+            topk_weights: top-k weights, shape [batch_size*seq_len, num_experts_per_tok]
+
+        Returns:
+            output: output hidden states, shape [batch_size*seq_len, hidden_size]
+        """
+        # Run the communication strategy for EP parallelism; it returns only
+        # the output tensor, since gate losses are tracked on the gate.
+        output = self.communication.forward(
+            hidden_states,
+            topk_indices,
+            topk_weights,
+            self.expert_parallel_degree,
+            self.moe_group,
+            self.experts,
+            self.moe_rank,
+            self.num_experts_per_device,
+            self.num_experts,
+            self.num_experts_per_tok,
+        )
+        return output
+
+    def get_auxiliary_loss(self) -> paddle.Tensor:
+        """
+        Get the auxiliary loss.
+
+        Returns:
+            aux_loss: scalar auxiliary loss
+        """
+        return self.gate.get_auxiliary_loss()
+
+    def get_z_loss(self) -> paddle.Tensor:
+        """
+        Get the z loss.
+
+        Returns:
+            z_loss: scalar z loss
+        """
+        return self.gate.get_z_loss()
+
+    def get_all_losses(self) -> Dict[str, paddle.Tensor]:
+        """Get all losses (flexible mode)."""
+        if hasattr(self.gate, "get_all_losses"):
+            return self.gate.get_all_losses()
+        else:
+            return {"auxiliary": self.get_auxiliary_loss(), "z_loss": self.get_z_loss()}
+
+    def get_total_loss(self) -> paddle.Tensor:
+        """Get the total loss (flexible mode)."""
+        if hasattr(self.gate, "get_total_loss"):
+            return self.gate.get_total_loss()
+        else:
+            return self.get_auxiliary_loss() + self.get_z_loss()
+
+    # Flexible loss-management helpers
+    def add_loss_function(
+        self,
+        name: str,
+        loss_func: LossFunction,
+        weight: float = 0.0,
+        loss_type: LossType = LossType.CUSTOM,
+        enabled: bool = True,
+        params: Optional[Dict[str, Any]] = None,
+    ):
+        """Add a custom loss function."""
+        if not self.use_flexible_loss:
+            logger.warning("Legacy loss mode is active; custom loss functions cannot be added")
+            return
+
+        # Register the loss function
+        global_loss_registry.register_loss(name, loss_func)
+
+        # Add the loss configuration
+        config = LossConfig(name, loss_type, weight, enabled, params or {})
+        if hasattr(self.gate, "add_loss_config"):
+            self.gate.add_loss_config(config)
+        else:
+            logger.warning("The current gate does not support dynamically adding loss functions")
+
+    def remove_loss_function(self, name: str):
+        """Remove a loss function."""
+        if not self.use_flexible_loss:
+            logger.warning("Legacy loss mode is active; loss functions cannot be removed")
+            return
+
+        if hasattr(self.gate, "remove_loss_config"):
+            self.gate.remove_loss_config(name)
+        else:
+            logger.warning("The current gate does not support dynamically removing loss functions")
+
+    def update_loss_weights(self, weights: Dict[str, float]):
+        """Update loss weights."""
+        if not self.use_flexible_loss:
+            logger.warning("Legacy loss mode is active; loss weights cannot be updated dynamically")
+            return
+
+        if hasattr(self.gate, "update_loss_weights"):
+            self.gate.update_loss_weights(weights)
+        else:
+            logger.warning("The current gate does not support dynamically updating loss weights")
+
+    def set_loss_combiner(self, combiner_name: str):
+        """Set the loss combiner."""
+        if not self.use_flexible_loss:
+            logger.warning("Legacy loss mode is active; the loss combiner cannot be changed")
+            return
+
+        if hasattr(self.gate, "set_loss_combiner"):
+            self.gate.set_loss_combiner(combiner_name)
+        else:
+            logger.warning("The current gate does not support dynamically setting the loss combiner")
+
+    def get_expert_info(self) -> Dict[str, Any]:
+        """
+        Get expert information.
+
+        Returns:
+            expert_info: dictionary of expert-related metadata
+        """
+        return {
+            "num_experts": self.num_experts,
+            "num_experts_per_device": self.num_experts_per_device,
+            "expert_parallel_degree": self.expert_parallel_degree,
+            "moe_rank": self.moe_rank,
+            "is_parallel_enabled": self.expert_parallel_degree > 1,
+            "use_flexible_loss": self.use_flexible_loss,
+        }
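+
+
+# Training-loop sketch (illustrative): with use_flexible_loss enabled, the
+# routing losses can be pulled off the layer after each forward pass:
+#
+#   output, _ = moe_layer(hidden_states)
+#   loss = task_loss + moe_layer.get_total_loss()
+#
+# get_all_losses() returns the per-loss breakdown, which is convenient for logging.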
"num_experts": self.num_experts, + "num_experts_per_device": self.num_experts_per_device, + "expert_parallel_degree": self.expert_parallel_degree, + "moe_rank": self.moe_rank, + "is_parallel_enabled": self.expert_parallel_degree > 1, + "use_flexible_loss": self.use_flexible_loss, + } diff --git a/paddleformers/nn/moe_deepep/moe_communication.py b/paddleformers/nn/moe_deepep/moe_communication.py new file mode 100644 index 0000000000..15dfd62175 --- /dev/null +++ b/paddleformers/nn/moe_deepep/moe_communication.py @@ -0,0 +1,292 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, List, Tuple + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle import Tensor, nn +from paddle.distributed.communication.group import Group + +from ...transformers.token_dispatcher import MoEFlexTokenDispatcher + + +class MoECommunicationInterface(ABC): + """ + MoE通信接口 + + 定义EP并行通信的标准接口,支持不同的通信策略 + """ + + @abstractmethod + def forward( + self, + hidden_states: paddle.Tensor, + topk_indices: paddle.Tensor, + topk_weights: paddle.Tensor, + expert_parallel_degree: int, + moe_group: Group, + experts: nn.LayerList, + moe_rank: int, + num_experts_per_device: int, + num_experts: int, + topk: int, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ + EP并行通信前向传播 + + Args: + hidden_states: 输入隐藏状态 + topk_indices: TopK专家索引 + topk_weights: TopK权重 + expert_parallel_degree: 专家并行度 + moe_group: MoE通信组 + + Returns: + output: 输出隐藏状态 + aux_loss: 辅助损失 + z_loss: Z损失 + """ + pass + + +class StandardMoECommunication(nn.Layer, MoECommunicationInterface): + """ + 标准MoE通信实现 + + 基于All-to-All通信的EP并行实现 + """ + + def forward( + self, + hidden_states: paddle.Tensor, + topk_indices: paddle.Tensor, + topk_weights: paddle.Tensor, + expert_parallel_degree: int, + moe_group: Group, + experts: nn.LayerList, + moe_rank: int, + num_experts_per_device: int, + num_experts: int, + topk: int, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ + EP并行通信前向传播 + + Args: + hidden_states: 输入隐藏状态 + topk_indices: TopK专家索引 + topk_weights: TopK权重 + expert_parallel_degree: 专家并行度 + moe_group: MoE通信组 + + Returns: + output: 输出隐藏状态 + aux_loss: 辅助损失 + z_loss: Z损失 + """ + if expert_parallel_degree <= 1: + # 无需EP并行,直接返回 + return hidden_states + + # 计算每个专家的token数量 + cnts = paddle.zeros([topk_indices.shape[0], topk_indices.shape[1]], dtype=topk_indices.dtype) + cnts = cnts.put_along_axis(topk_indices, 1, axis=1) + tokens_per_expert = cnts.sum(axis=0) + + # 排序token + idxs = topk_indices.reshape([topk_indices.shape[0] * topk_indices.shape[1]]).argsort() + sorted_tokens = hidden_states[idxs // topk_indices.shape[1]] + tokens_per_expert = tokens_per_expert.detach() + sorted_tokens_shape = sorted_tokens.shape + + # EP并行通信 + # 计算每个EP rank的token数量 + tokens_per_ep_rank = tokens_per_expert.reshape([expert_parallel_degree, -1]).sum(axis=1) + + # 第一次All-to-All:交换token数量信息 + tokens_per_expert_group = 
_AllToAll.apply([tokens_per_expert.shape[0]], tokens_per_expert, group=moe_group) + + # 计算输出分割大小 + output_splits = tokens_per_expert_group.reshape([expert_parallel_degree, -1]).sum(axis=1).cpu().tolist() + input_split_sizes = tokens_per_ep_rank.cpu().tolist() + + # 第二次All-to-All:交换token数据 + gathered_tokens = _AllToAll.apply( + [tokens_per_expert_group.sum(axis=0).cpu().item(), sorted_tokens.shape[1]], + sorted_tokens, + out_split_sizes=output_splits, + in_split_sizes=input_split_sizes, + group=moe_group, + ) + + # 计算聚合后的每个专家token数量 + tokens_per_expert_post_gather = tokens_per_expert_group.reshape([expert_parallel_degree, -1]).sum(axis=0) + + # 创建聚合索引 + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % (tokens_per_expert_post_gather.shape[0] // expert_parallel_degree) + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + + # expert 计算前向 + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = experts[i + moe_rank * num_experts_per_device] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + outs = paddle.concat(outputs, axis=0) if len(outputs) > 0 else paddle.to_tensor(0, dtype=sorted_tokens.dtype) + + # 第三次All-to-All:将专家输出分发回原始位置 + new_x = paddle.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = _AllToAll.apply( + sorted_tokens_shape, + new_x, + out_split_sizes=input_split_sizes, + in_split_sizes=output_splits, + group=moe_group, + ) + outs = gathered_tokens + + # 最终聚合 + new_x = paddle.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.reshape(topk_indices.shape + [-1]) + .astype(topk_weights.dtype) + .multiply_(topk_weights.unsqueeze(-1)) + .sum(axis=1) + .astype(new_x.dtype) + ) + + return final_out + + +class DeepEPMoECommunication(nn.Layer, MoECommunicationInterface): + """ + DeepEP MoE 通信实现 + + 基于 DeepEP 通信的 EP 并行实现 + """ + + def expert_forward(self, dispatched_input, tokens_per_expert, experts, moe_rank, num_experts_per_device): + outputs = [] + tokens_per_expert = ( + tokens_per_expert.tolist() if not isinstance(tokens_per_expert, list) else tokens_per_expert + ) + # print(f"all tokens: {sum(tokens_per_expert)}, detail: {tokens_per_expert}") + chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0) + for i, chunk in enumerate(chunks): + chunk = chunk.contiguous() + # assert chunk.shape[0] != 0, "Cannot dispatch empty input" + expert = experts[i + moe_rank * num_experts_per_device] + outputs += [expert(chunk)] + + return paddle.concat(outputs, axis=0) + + def forward( + self, + hidden_states: paddle.Tensor, + topk_indices: paddle.Tensor, + topk_weights: paddle.Tensor, + expert_parallel_degree: int, + moe_group: Group, + experts: nn.LayerList, + moe_rank: int, + num_experts_per_device: int, + num_experts: int, + topk: int, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + if expert_parallel_degree <= 1: + # 无需EP并行,直接返回 + return hidden_states + token_dispatcher = MoEFlexTokenDispatcher(num_experts_per_device, topk, num_experts, moe_group) + (dispatched_input, tokens_per_expert) = token_dispatcher.token_permutation( + hidden_states, topk_indices, topk_weights + ) + expert_output = 
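+
+
+# Dispatch sketch (illustrative numbers): with expert_parallel_degree=2 and 4
+# experts (2 per rank), suppose this rank's tokens_per_expert is [3, 1, 2, 0].
+# Then tokens_per_ep_rank = [4, 2]: four rows go to rank 0 and two rows to
+# rank 1 in the second all-to-all. The first all-to-all exchanges the counts
+# themselves, so each rank can size its receive buffer and output_splits list
+# before any token data moves.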
+
+
+class DeepEPMoECommunication(nn.Layer, MoECommunicationInterface):
+    """
+    DeepEP MoE communication.
+
+    EP-parallel implementation based on DeepEP communication.
+    """
+
+    def expert_forward(self, dispatched_input, tokens_per_expert, experts, moe_rank, num_experts_per_device):
+        outputs = []
+        tokens_per_expert = (
+            tokens_per_expert.tolist() if not isinstance(tokens_per_expert, list) else tokens_per_expert
+        )
+        chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0)
+        for i, chunk in enumerate(chunks):
+            chunk = chunk.contiguous()
+            expert = experts[i + moe_rank * num_experts_per_device]
+            outputs += [expert(chunk)]
+
+        return paddle.concat(outputs, axis=0)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        topk_indices: paddle.Tensor,
+        topk_weights: paddle.Tensor,
+        expert_parallel_degree: int,
+        moe_group: Group,
+        experts: nn.LayerList,
+        moe_rank: int,
+        num_experts_per_device: int,
+        num_experts: int,
+        topk: int,
+    ) -> paddle.Tensor:
+        if expert_parallel_degree <= 1:
+            # No EP parallelism needed; return the input unchanged
+            return hidden_states
+        token_dispatcher = MoEFlexTokenDispatcher(num_experts_per_device, topk, num_experts, moe_group)
+        (dispatched_input, tokens_per_expert) = token_dispatcher.token_permutation(
+            hidden_states, topk_indices, topk_weights
+        )
+        expert_output = self.expert_forward(
+            dispatched_input, tokens_per_expert, experts, moe_rank, num_experts_per_device
+        )
+        output, _ = token_dispatcher.token_unpermutation(expert_output, None)
+        return output
+
+
+class _AllToAll(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(
+        ctx: Any,
+        output_shape: List,
+        input: Tensor,
+        out_split_sizes: List = None,
+        in_split_sizes: List = None,
+        group: Group = None,
+    ) -> Tensor:  # type: ignore
+        """
+        All-to-all communication in the group.
+        Args:
+            ctx (Any): Context object.
+            output_shape (List): Output shape.
+            input (Tensor): Input tensor.
+            out_split_sizes (List): Output split sizes.
+            in_split_sizes (List): Input split sizes.
+            group (Group): The group object.
+        Returns:
+            Tensor: Output tensor.
+        """
+
+        ctx.group = group
+        ctx.input_shape = input.shape
+        ctx.out_split_sizes = out_split_sizes
+        ctx.in_split_sizes = in_split_sizes
+
+        if dist.get_world_size(group) <= 1:
+            return input
+
+        output = paddle.empty(output_shape, dtype=input.dtype)
+        task = dist.alltoall_single(
+            output,
+            input,
+            out_split_sizes=out_split_sizes,
+            in_split_sizes=in_split_sizes,
+            sync_op=False,
+            group=group,
+        )
+        task.wait()
+
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]:
+        """
+        Runs the reverse all-to-all so gradients flow back to their source ranks.
+        Args:
+            ctx (Any): The context object holding the forward split sizes.
+            *grad_output (Tensor): Gradients of the forward output.
+        Returns:
+            Tuple[Tensor]: Gradient with respect to the forward input.
+        """
+        # The split sizes are swapped relative to the forward pass
+        return _AllToAll.apply(ctx.input_shape, *grad_output, ctx.in_split_sizes, ctx.out_split_sizes, ctx.group)
diff --git a/paddleformers/nn/moe_deepep/moe_config.json b/paddleformers/nn/moe_deepep/moe_config.json
new file mode 100644
index 0000000000..1f4afe9d78
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/moe_config.json
@@ -0,0 +1,23 @@
+{
+    "qwen3_moe": {
+        "expert_parallel_degree": 1,
+        "gate_activation": "softmax",
+        "expert_activation": "silu",
+        "train_topk_method": "greedy",
+        "inference_topk_method": "greedy",
+        "aux_loss_weight": 0.01,
+        "z_loss_weight": 0.0,
+        "expert_dropout": 0.0,
+        "use_flexible_loss": false
+    },
+    "deepseek_v3": {
+        "expert_parallel_degree": 1,
+        "gate_activation": "softmax",
+        "expert_activation": "silu",
+        "train_topk_method": "group_limited_greedy",
+        "inference_topk_method": "noaux_tc",
+        "aux_loss_weight": 0.01,
+        "z_loss_weight": 0.0,
+        "expert_dropout": 0.0,
+        "use_flexible_loss": false
+    }
+}
diff --git a/paddleformers/nn/moe_deepep/moe_expert.py b/paddleformers/nn/moe_deepep/moe_expert.py
new file mode 100644
index 0000000000..15708e258f
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/moe_expert.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Dict
+
+import paddle
+from paddle import nn
+
+from ...transformers import Linear, linear_utils
+from ...transformers.activations import ACT2FN
+from ...transformers.llama import fusion_ops
+from ...transformers.refined_recompute import (
+    RRColumnParallelLinear,
+    RRColumnSequenceParallelLinear,
+    RRRowParallelLinear,
+    RRRowSequenceParallelLinear,
+)
+
+
+class MoEExpertInterface(ABC):
+    """
+    MoE expert-network interface.
+
+    Defines the standard interface for expert networks so that different
+    expert architectures can be plugged in.
+    """
+
+    @abstractmethod
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """
+        Expert-network forward pass.
+
+        Args:
+            hidden_states: input hidden states, shape [seq_len, hidden_size]
+
+        Returns:
+            output: output hidden states, shape [seq_len, hidden_size]
+        """
+        pass
+
+
+class StandardMoEExpert(nn.Layer, MoEExpertInterface):
+    """
+    Standard MoE expert network.
+
+    A unified implementation supporting several expert configurations.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        expert_activation: str,
+        expert_dropout: float,
+        config: Dict = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.expert_activation = expert_activation
+        self.expert_dropout = expert_dropout
+
+        # MLP projections
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias_attr=False)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias_attr=False)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias_attr=False)
+
+        # Activation function (defaults to SiLU for unknown names)
+        if self.expert_activation == "silu":
+            self.activation = paddle.nn.functional.silu
+        elif self.expert_activation == "gelu":
+            self.activation = paddle.nn.functional.gelu
+        elif self.expert_activation == "relu":
+            self.activation = paddle.nn.functional.relu
+        else:
+            self.activation = paddle.nn.functional.silu
+
+        # Dropout
+        if self.expert_dropout > 0.0:
+            self.dropout = nn.Dropout(self.expert_dropout)
+        else:
+            self.dropout = None
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """
+        Expert-network forward pass.
+        Args:
+            hidden_states: input hidden states, shape [seq_len, hidden_size]
+        Returns:
+            output: output hidden states, shape [seq_len, hidden_size]
+        """
+        # Gate and up projections
+        gate = self.gate_proj(hidden_states)
+        up = self.up_proj(hidden_states)
+
+        # Apply the activation
+        intermediate = self.activation(gate) * up
+
+        # Apply dropout
+        if self.dropout is not None:
+            intermediate = self.dropout(intermediate)
+
+        # Down projection
+        output = self.down_proj(intermediate)
+
+        return output
+
+
+class Qwen2MoeMLP(nn.Layer):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        expert_activation: str,
+        expert_dropout: float,
+        config: Dict = None,
+    ):
+        super().__init__()
+        config = config or {}
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.expert_activation = expert_activation
+        self.expert_dropout = expert_dropout
+
+        self.skip_recompute_ops = config.get("skip_recompute_ops", {})
+        self.fuse_attention_ffn = config.get("fuse_attention_ffn", False)
+        self.tensor_parallel_degree = config.get("tensor_parallel_degree", 1)
+        self.sequence_parallel = config.get("sequence_parallel", False)
+        self.recompute = config.get("recompute", False)
+        self.recompute_use_reentrant = config.get("recompute_use_reentrant", False)
+
+        if self.sequence_parallel:
+            ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
+            RowParallelLinear = linear_utils.RowSequenceParallelLinear
+
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if self.recompute and not self.recompute_use_reentrant:
+                if self.skip_recompute_ops.get("mlp_column_ln", False):
+                    ColumnParallelLinear = RRColumnSequenceParallelLinear
+                if self.skip_recompute_ops.get("mlp_row_ln", False):
+                    RowParallelLinear = RRRowSequenceParallelLinear
+        else:
+            ColumnParallelLinear = linear_utils.ColumnParallelLinear
+            RowParallelLinear = linear_utils.RowParallelLinear
+
+            # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
+            if self.recompute and not self.recompute_use_reentrant:
+                if self.skip_recompute_ops.get("mlp_column_ln", False):
+                    ColumnParallelLinear = RRColumnParallelLinear
+                if self.skip_recompute_ops.get("mlp_row_ln", False):
+                    RowParallelLinear = RRRowParallelLinear
+
+        if self.tensor_parallel_degree > 1:
+            if self.fuse_attention_ffn:
+                self.gate_up_fused_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size * 2,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            else:
+                self.gate_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size,
+                    gather_output=False,
+                    has_bias=False,
+                )
+                self.up_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            self.down_proj = RowParallelLinear(
+                self.intermediate_size,
+                self.hidden_size,
+                input_is_parallel=True,
+                has_bias=False,
+            )
+        else:
+            if self.fuse_attention_ffn:
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+            else:
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)  # w1
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)  # w3
+            self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)  # w2
+
+        if self.expert_activation == "silu":
+            self.act_fn = fusion_ops.swiglu
+            self.fuse_swiglu = True
+        else:
+            self.act_fn = ACT2FN[self.expert_activation]
+            self.fuse_swiglu = False
+
+    def forward(self, x):
+        if self.fuse_attention_ffn:
+            x = self.gate_up_fused_proj(x)
+            if self.fuse_swiglu:
+                y = None
+            else:
+                x, y = x.chunk(2, axis=-1)
+        else:
+            x, y = self.gate_proj(x), self.up_proj(x)
+
+        if self.fuse_swiglu:
+            x = self.act_fn(x, y)
+        else:
+            x = self.act_fn(x) * y
+
+        return self.down_proj(x)
diff --git a/paddleformers/nn/moe_deepep/moe_factory.py b/paddleformers/nn/moe_deepep/moe_factory.py
new file mode 100644
index 0000000000..0beb786a6c
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/moe_factory.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from pathlib import Path
+from typing import Any, Dict
+
+from ...transformers.configuration_utils import PretrainedConfig
+from .modular_moe_layer import ModularMoELayer
+
+
+class QuickAccessMoEFactory:
+    _moe_configs: Dict[str, Dict[str, Any]] = None
+
+    @classmethod
+    def _load_moe_configs(cls) -> Dict[str, Dict[str, Any]]:
+        if cls._moe_configs is None:
+            config_path = Path(__file__).parent / "moe_config.json"
+            with open(config_path, "r", encoding="utf-8") as f:
+                cls._moe_configs = json.load(f)
+        return cls._moe_configs
+
+    @staticmethod
+    def create_from_model_name(
+        pretrained_config: PretrainedConfig,
+    ) -> ModularMoELayer:
+        moe_configs = QuickAccessMoEFactory._load_moe_configs()
+
+        model_type = getattr(pretrained_config, "model_type", None)
+        if model_type is None:
+            raise ValueError("Cannot determine model type from pretrained_config")
+
+        moe_config = moe_configs.get(model_type)
+        if moe_config is None:
+            raise ValueError(f"No MoE configuration found for model type: {model_type}")
+
+        # Use getattr with fallbacks: the field names differ across model families
+        return ModularMoELayer(
+            hidden_size=pretrained_config.hidden_size,
+            moe_intermediate_size=pretrained_config.moe_intermediate_size,
+            num_experts=getattr(
+                pretrained_config,
+                "num_experts",
+                getattr(pretrained_config, "n_routed_experts", getattr(pretrained_config, "moe_num_experts", -1)),
+            ),
+            num_shared_experts=getattr(
+                pretrained_config, "n_shared_experts", getattr(pretrained_config, "moe_num_shared_experts", 0)
+            ),
+            num_experts_per_tok=getattr(
+                pretrained_config, "num_experts_per_tok", getattr(pretrained_config, "moe_k", -1)
+            ),
+            norm_topk_prob=getattr(pretrained_config, "norm_topk_prob", True),
+            expert_activation=getattr(
+                pretrained_config, "hidden_act", getattr(pretrained_config, "expert_activation", "silu")
+            ),
+            moe_config=moe_config,
+        )
+
+
+__all__ = ["QuickAccessMoEFactory"]
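+
+# Usage sketch (illustrative; the checkpoint name below is hypothetical):
+# model_type on the config must match a key in moe_config.json, e.g. "qwen3_moe".
+#
+#   config = AutoConfig.from_pretrained("some/qwen3-moe-checkpoint")
+#   moe_layer = QuickAccessMoEFactory.create_from_model_name(config)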
diff --git a/paddleformers/nn/moe_deepep/moe_gate.py b/paddleformers/nn/moe_deepep/moe_gate.py
new file mode 100644
index 0000000000..5e0cf0e128
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/moe_gate.py
@@ -0,0 +1,747 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) Microsoft Corporation.
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Tuple
+
+import paddle
+import paddle.distributed as dist
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from ...utils.log import logger
+from .moe_loss import LossConfig, LossType
+
+
+class MoEGateMixin:
+    def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor:
+        # [..., hidden_dim] -> [..., num_experts]
+        with paddle.amp.auto_cast(False):
+            scoring_func = getattr(self, "scoring_func", None)
+            if scoring_func == "softmax":
+                scores = F.softmax(logits.cast("float32"), axis=-1)
+            elif scoring_func == "sigmoid":
+                scores = F.sigmoid(logits.cast("float32"))
+            elif scoring_func == "tanh":
+                scores = F.tanh(logits.cast("float32"))
+            elif scoring_func == "relu":
+                scores = F.relu(logits.cast("float32"))
+            elif scoring_func == "gelu":
+                scores = F.gelu(logits.cast("float32"))
+            elif scoring_func == "leaky_relu":
+                scores = F.leaky_relu(logits.cast("float32"))
+            else:
+                logger.warning_once(
+                    f"Unsupported scoring function for MoE gating: {scoring_func}; using softmax instead"
+                )
+                scores = F.softmax(logits.cast("float32"), axis=-1)
+        return scores
+
+    def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor:
+        gumbel = paddle.distribution.gumbel.Gumbel(0, 1)
+        return gumbel.rsample(logits.shape)
+
+    def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor:
+        uniform = paddle.distribution.uniform.Uniform(0, 1)
+        return uniform.sample(logits.shape)
+
+    @paddle.no_grad()
+    def _one_hot_to_float(self, x, num_classes):
+        if x.dtype not in (paddle.int32, paddle.int64):
+            x = paddle.cast(x, paddle.int64)
+        return F.one_hot(x, num_classes=num_classes).cast(paddle.get_default_dtype())
+
+    @paddle.no_grad()
+    def _one_hot_to_int64(self, x, num_classes):
+        if x.dtype not in (paddle.int32, paddle.int64):
+            x = paddle.cast(x, paddle.int64)
+        return F.one_hot(x, num_classes=num_classes).cast(paddle.int64)
+
+    @paddle.no_grad()
+    def _capacity(
+        self,
+        gates: paddle.Tensor,
+        capacity_factor: float,
+        max_capacity: int,
+        min_capacity: int,
+    ) -> paddle.Tensor:
+        """Calculate the per-expert capacity from the gates and capacity factor.
+
+        Args:
+            gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability
+                distribution over experts for each token.
+            capacity_factor (float): A scalar float representing the capacity factor for each expert.
+            max_capacity (int): A scalar integer representing the maximum capacity for each expert.
+            min_capacity (int): A scalar integer representing the minimum capacity for each expert.
+
+        Returns:
+            int: The calculated capacity for each expert.
+        """
+        assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}"
+        # gates has shape [S, E]
+        num_tokens = gates.shape[0]
+        num_experts = gates.shape[1]
+        capacity = int((num_tokens // num_experts) * capacity_factor)
+        if capacity < min_capacity:
+            capacity = min_capacity
+        if capacity > max_capacity:
+            capacity = max_capacity
+        assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}"
+
+        return capacity
+
+    def _cal_aux_loss(self, gates, mask):
+        """
+        Calculate the auxiliary (load-balancing) loss.
+
+        Args:
+            gates (paddle.Tensor): Output probability of each expert, shape [batch_size, num_experts].
+            mask (paddle.Tensor): Whether each sample is routed to a given expert, shape [batch_size, num_experts].
+
+        Returns:
+            paddle.Tensor: The auxiliary loss value.
+        """
+        # TODO: @DrownFish19 update aux_loss for Qwen2MoE and DeepSeekV2&V3
+        me = paddle.mean(gates, axis=0)
+        ce = paddle.mean(mask.cast("float32"), axis=0)
+        if self.global_aux_loss:
+            me_list, ce_list = [], []
+            dist.all_gather(me_list, me, group=self.group)
+            dist.all_gather(ce_list, ce, group=self.group)
+
+            me_list[self.rank] = me
+            ce_list[self.rank] = ce
+            me = paddle.stack(me_list).mean(0)
+            ce = paddle.stack(ce_list).mean(0)
+        aux_loss = paddle.sum(me * ce) * float(self.num_experts)
+        return aux_loss
+
+    def _cal_seq_aux_loss(self, gates, num_experts_per_tok, topk_idx) -> paddle.Tensor:
+        """
+        Calculate the sequence-level auxiliary loss.
+
+        Args:
+            gates (paddle.Tensor): Gate probabilities, shape [batch_size, seq_len, num_experts].
+            num_experts_per_tok (int): Number of experts selected per token.
+            topk_idx (paddle.Tensor): Selected expert indices.
+
+        Returns:
+            paddle.Tensor: The sequence auxiliary loss value.
+        """
+        batch_size, seq_len, _ = gates.shape
+        ce = paddle.zeros([batch_size, self.num_experts])
+        topk_idx = topk_idx.reshape([batch_size, -1])
+        ce.put_along_axis_(
+            indices=topk_idx, values=paddle.ones([batch_size, seq_len * num_experts_per_tok]), axis=1, reduce="add"
+        )
+        ce = ce / (seq_len * num_experts_per_tok / self.num_experts)
+        aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
+        return aux_loss
+
+    def _cal_z_loss(self, logits) -> paddle.Tensor:
+        """
+        Calculate the z loss.
+
+        Args:
+            logits (paddle.Tensor): Gate logits, shape [batch_size, num_experts].
+
+        Returns:
+            paddle.Tensor: The z loss value.
+        """
+        l_zloss = paddle.logsumexp(logits, axis=1).square().mean()
+        return l_zloss
+
+    def _cal_orthogonal_loss(self) -> paddle.Tensor:
+        """Gate-weight orthogonality loss.
+
+        Returns:
+            paddle.Tensor: orthogonal loss
+        """
+        weight = F.normalize(self.weight, axis=0)
+        orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts)))
+        return orthogonal_loss
+
+    def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor:
+        """Position-based priority: a token keeps its assignment if its
+        cumulative position within the chosen expert is still under capacity.
+
+        This method is used in the Hunyuan model.
+        Args:
+            topk_idx (paddle.Tensor): [batch_size * seq_len, topk]
+            capacity (int): per-expert capacity
+
+        Returns:
+            paddle.Tensor: cumsum-based priority mask
+        """
+        _, k = topk_idx.shape
+        # Shape: [seq_len * k]
+        chosen_expert = topk_idx.reshape([-1])
+        # Shape: [seq_len * k, num_experts].
+        token_priority = F.one_hot(chosen_expert, self.num_experts).cast(paddle.int32)
+        token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) <= capacity)
+        # Shape: [seq_len, num_experts].
+        token_priority = token_priority.reshape([-1, k, self.num_experts]).sum(axis=1)
+
+        return (token_priority > 0.0).astype("float32")
+
+    def _probs_drop_policy(
+        self,
+        scores: paddle.Tensor,
+        capacity: int,
+    ) -> paddle.Tensor:
+        """
+        Implements the probability-based ("probs") drop policy to enforce expert capacity.
+
+        A token is assigned (mask value 1.0) to an expert if:
+        1. It chose that expert (score > 0), which is implicit in the input scores.
+        2. Its score for that expert is among the top `capacity` scores for that expert.
+
+        Args:
+            scores (paddle.Tensor): [num_tokens, num_total_experts]. Should already
+                contain zeros for non-selected experts (i.e., the result of top-k gating).
+            capacity (int): The maximum number of tokens any single expert can handle.
+
+        Returns:
+            paddle.Tensor: [num_tokens, num_total_experts] mask as float.
+                1.0 = assigned and within capacity. 0.0 = dropped or unassigned.
+        """
+        num_tokens, num_experts = scores.shape
+
+        # --- Step 1: Find the 'capacity' best tokens for each expert ---
+
+        # topk along axis=0 (the token axis) returns, for each expert column,
+        # the indices of the tokens with the highest scores.
+        # topk_token_indices has shape [k_to_use, num_total_experts].
+
+        # Use min(num_tokens, capacity) in case there are fewer tokens than capacity.
+        k_to_use = min(num_tokens, capacity)
+
+        # Only the indices of the selected tokens are needed
+        _, topk_token_indices = paddle.topk(
+            scores,
+            k=k_to_use,
+            axis=0,
+            sorted=True,
+        )
+
+        # --- Step 2: Create the final assignment mask via scatter ---
+
+        # Initialize the mask to all zeros (tokens start as dropped/unassigned).
+        final_mask = paddle.zeros([num_tokens, num_experts], dtype="bool")
+
+        # 2a. Column indices for the assignment: each row is [0, 1, ..., num_experts-1].
+        col_indices = paddle.arange(num_experts).unsqueeze(0).expand_as(topk_token_indices)
+
+        # 2b. Flatten the row (token) and column (expert) indices for advanced indexing.
+        token_indices_flat = topk_token_indices.flatten()
+        col_indices_flat = col_indices.flatten()
+
+        # 2c. Set mask[token_index, expert_index] = True for all prioritized tokens.
+        final_mask[token_indices_flat, col_indices_flat] = True
+
+        # --- Step 3: Keep only tokens that were originally selected ---
+
+        # topk can pick up zero-score tokens when num_tokens < capacity, so
+        # filter out any spurious assignments made on zero scores.
+        token_priority_mask = final_mask.astype("float32") * (scores > 0).astype("float32")
+
+        return token_priority_mask
+
+    def _topk_greedy(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Select the top-k experts per token by score.
+
+        Args:
+            scores (paddle.Tensor): [bsz*seq_len, n_experts]
+            k (int): select the top k experts
+
+        Returns:
+            Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx
+                topk_weight: [bsz*seq_len, k]
+                topk_idx: [bsz*seq_len, k]
+        """
+        topk_weight, topk_idx = paddle.topk(scores, k=k, axis=-1, sorted=True)
+
+        return topk_weight, topk_idx
+
+    def _topk_group_limited_greedy(
+        self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Group-limited greedy top-k: first select the best expert groups,
+        then pick the top-k experts within the selected groups.
+
+        Args:
+            scores (paddle.Tensor): [bsz*seq_len, n_experts]
+            k (int): number of experts selected overall
+            n_group (int): the number of groups all experts are divided into
+            topk_group (int): the number of groups selected
+
+        Returns:
+            Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx
+                topk_weight: [bsz*seq_len, k]
+                topk_idx: [bsz*seq_len, k]
+
+        Note: the group size is normally greater than k.
+        """
+        bsz_seq_len, n_experts = scores.shape
+        assert n_experts % n_group == 0, "n_experts must be divisible by n_group"
+
+        group_scores = scores.reshape([0, n_group, -1]).max(axis=-1)  # [n, n_group]
+        group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]  # [n, top_k_group]
+        group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1)  # fmt:skip
+        score_mask = (
+            group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1])
+        )  # [n, e]
+        tmp_scores = scores * score_mask  # [n, e]
+        topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True)
+
+        return topk_weight, topk_idx
+
+    def _topk_noaux_tc(
+        self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Aux-loss-free top-k with a score-correction bias (DeepSeek-V3 style).
+
+        Args:
+            scores (paddle.Tensor): [bsz*seq_len, n_experts]
+            k (int): number of experts selected overall
+            n_group (int): the number of groups all experts are divided into
+            topk_group (int): the number of groups selected
+
+        Returns:
+            Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx
+                topk_weight: [bsz*seq_len, k]
+                topk_idx: [bsz*seq_len, k]
+
+        Note: the group size is normally greater than k.
+        """
+        bsz_seq_len, n_experts = scores.shape
+        assert n_experts % n_group == 0, "n_experts must be divisible by n_group"
+
+        assert self.e_score_correction_bias is not None, "e_score_correction_bias is None"
+        scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.detach().unsqueeze(0)
+        group_scores = (
+            scores_for_choice.reshape([bsz_seq_len, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
+        )  # fmt:skip [n, n_group]
+        group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]  # [n, top_k_group]
+        group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0, dtype="float32"), axis=-1)  # fmt:skip
+        score_mask = (
+            group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1])
+        )  # [n, e]
+        tmp_scores = scores_for_choice * score_mask  # [n, e]
+        topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True)
+        topk_weight = scores.take_along_axis(topk_idx, axis=1) if not self.training else topk_weight
+
+        return topk_weight, topk_idx
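+
+    # Capacity sketch (illustrative numbers): with 1024 tokens, 64 experts,
+    # num_experts_per_tok=8 and capacity_factor=1.0, topkgating() in
+    # PretrainedMoEGate calls _capacity() with an effective factor of
+    # 1.0 * 8 = 8.0, giving (1024 // 64) * 8.0 = 128 slots per expert before
+    # the drop policy runs.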
+
+
+class PretrainedMoEGate(nn.Layer, MoEGateMixin):
+    def __init__(
+        self,
+        num_experts: int,
+        expert_hidden_size: int,
+        drop_tokens: bool,
+        topk_method: str,
+        num_experts_per_tok: int,
+        norm_topk_prob: bool,
+        moe_config: Dict,
+    ):
+        super(PretrainedMoEGate, self).__init__()
+
+        self.num_experts = num_experts
+        self.expert_hidden_size = expert_hidden_size
+        self.drop_tokens = drop_tokens
+        # Qwen2MoE: greedy
+        # DeepSeekV2&V3: group_limited_greedy for training, and noaux_tc for inference
+        self.topk_method = topk_method
+        self.num_experts_per_tok = num_experts_per_tok
+        self.norm_topk_prob = norm_topk_prob
+        # force keep in float32 when using amp
+        self._cast_to_low_precision = False
+
+        self.scoring_func = moe_config.get("scoring_func", "sigmoid")
+        self.capacity_factor = moe_config.get("capacity_factor", 1.0)
+        self.eval_capacity_factor = moe_config.get("eval_capacity_factor", 1.0)
+        self.min_capacity = moe_config.get("min_capacity", 1)
+        self.max_capacity = moe_config.get("max_capacity", pow(2, 32))
+        self.group = moe_config.get("group", None)
+        self.global_aux_loss = moe_config.get("global_aux_loss", False)
+        self.use_rts = moe_config.get("use_rts", True)
+        self.top2_2nd_expert_sampling = moe_config.get("top2_2nd_expert_sampling", True)
+        self.drop_policy = moe_config.get("drop_policy", "probs")
+        self.n_group = moe_config.get("n_group", 1)  # for group_limited_greedy
+        self.topk_group = moe_config.get("topk_group", 1)  # for group_limited_greedy
+        self.routed_scaling_factor = moe_config.get("routed_scaling_factor", 1.0)
+        self.seq_aux = moe_config.get("seq_aux", False)
+
+        if self.global_aux_loss:
+            assert self.group is not None, "group is required when global_aux_loss is True"
+            self.rank = dist.get_rank(self.group)
+
+        # weight projecting hidden_state -> score
+        self.weight = paddle.create_parameter(
+            shape=[self.expert_hidden_size, self.num_experts],
+            dtype="bfloat16",
+            default_initializer=paddle.nn.initializer.Uniform(),
+        )
+
+    def forward(
+        self,
+        gates: paddle.Tensor,
+    ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        return self.topkgating(gates)
+
+    def topkgating(
+        self,
+        gates: paddle.Tensor,
+    ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Implements top-k gating on logits.
+
+        Note: `gates` is the incoming hidden state of shape
+        [batch_size, seq_len, hidden_size]; it is projected to expert scores below.
+        """
+        batch_size, seq_len, d_model = gates.shape
+        gates_ori = gates
+        gates = gates.reshape([-1, d_model])
+
+        # Project the hidden state to per-expert preference scores
+        with paddle.amp.auto_cast(False):
+            hidden_states = gates.cast(self.weight.dtype)
+            logits = F.linear(hidden_states.cast("float32"), self.weight.cast("float32"))
+            gates = self.gate_score_func(logits=logits)
+
+        l_zloss = self._cal_z_loss(gates)
+
+        # get topk gates
+        if self.topk_method == "greedy":
+            top_gate, top_idx = self._topk_greedy(gates, k=self.num_experts_per_tok)
+        elif self.topk_method == "group_limited_greedy":
+            top_gate, top_idx = self._topk_group_limited_greedy(
+                gates, k=self.num_experts_per_tok, n_group=self.n_group, topk_group=self.topk_group
+            )
+        elif self.topk_method == "noaux_tc":
+            top_gate, top_idx = self._topk_noaux_tc(
+                gates, k=self.num_experts_per_tok, n_group=self.n_group, topk_group=self.topk_group
+            )
+        else:
+            raise NotImplementedError(f"Invalid topk_method: {self.topk_method}")
+
+        # normalize the top-k gates to sum to 1
+        if self.num_experts_per_tok > 1 and self.norm_topk_prob:
+            denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20
+            top_gate = top_gate / denominator
+        top_gate = top_gate * self.routed_scaling_factor
+
+        # get topk mask
+        mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype=gates.dtype), axis=1)
+
+        if self.seq_aux:
+            l_aux = self._cal_seq_aux_loss(gates_ori, self.num_experts_per_tok, top_idx)
+        else:
+            l_aux = self._cal_aux_loss(gates, mask)
+
+        exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
+
+        if self.drop_tokens:
+            # Calculate the configured capacity and drop assignments beyond it
+            capacity = self._capacity(
+                gates,
+                self.capacity_factor * self.num_experts_per_tok,
+                self.max_capacity,
+                self.min_capacity,
+            )
+
+            # update mask and locations by capacity
+            if self.drop_policy == "probs":
+                topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
+                token_priority = self._probs_drop_policy(topk_masked_gates, capacity)
+            elif self.drop_policy == "position":
+                token_priority = self._priority(top_idx, capacity)
+            else:
+                raise ValueError(f"Invalid drop_policy: {self.drop_policy}")
+        else:
+            # Do not drop tokens; set capacity according to current expert assignments
+            local_capacity = paddle.max(exp_counts)
+            if self.group is not None:
+                dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group)
+            capacity = int(local_capacity)
+            token_priority = self._priority(top_idx, capacity)
+
+        # normalize gates
+        gates_masked = gates * mask
+        gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
+        denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
+        if self.norm_topk_prob:
+            gates_masked = gates_masked / denom_s
+        gates_masked *= self.routed_scaling_factor
+
+        return (
+            capacity,
+            gates_masked.take_along_axis(top_idx, axis=-1),
+            top_idx,
+            token_priority.take_along_axis(top_idx, axis=-1),
+            l_aux,
+            l_zloss,
+        )
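+
+
+# Gate output sketch (illustrative): the tuple returned by topkgating() is
+# unpacked in ModularMoELayer.forward as
+#
+#   capacity, topk_weights, topk_indices, priorities, aux_loss, z_loss = self.gate(hidden_states)
+#
+# where topk_weights, topk_indices and priorities all have shape
+# [batch_size * seq_len, num_experts_per_tok].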
= moe_config.get("min_capacity", 1) + self.max_capacity = moe_config.get("max_capacity", pow(2, 32)) + self.group = moe_config.get("group", None) + self.global_aux_loss = moe_config.get("global_aux_loss", False) + self.use_rts = moe_config.get("use_rts", True) + self.top2_2nd_expert_sampling = moe_config.get("top2_2nd_expert_sampling", True) + self.drop_policy = moe_config.get("drop_policy", "probs") + self.n_group = moe_config.get("n_group", 1) # for group_limited_greedy + self.topk_group = moe_config.get("topk_group", 1) # for group_limited_greedy + self.routed_scaling_factor = moe_config.get("routed_scaling_factor", 1.0) + self.seq_aux = moe_config.get("seq_aux", False) + + # 损失相关配置 + self.loss_registry = loss_registry + self.loss_configs = loss_configs or [ + LossConfig("auxiliary", LossType.AUXILIARY, weight=0.01), + LossConfig("z_loss", LossType.Z_LOSS, weight=0.0), + ] + + # 设置损失组合器 + self.loss_combiner = loss_registry.get_combiner(loss_combiner_name) + if self.loss_combiner is None: + logger.warning(f"未找到损失组合器: {loss_combiner_name}, 使用默认组合器") + self.loss_combiner = loss_registry.get_combiner("weighted_sum") + + # 初始化损失存储 + self.current_losses = {} + self.total_loss = paddle.to_tensor(0.0) + + # weight of hidden_state -> score + self.weight = paddle.create_parameter( + shape=[self.expert_hidden_size, self.num_experts], + dtype="bfloat16", + default_initializer=paddle.nn.initializer.Uniform(), + ) + + def _gumbel_softmax(self, logits: paddle.Tensor, temperature: float = 1.0) -> paddle.Tensor: + """Gumbel-Softmax采样""" + gumbel_noise = -paddle.log(-paddle.log(paddle.rand_like(logits) + 1e-8) + 1e-8) + return paddle.nn.functional.softmax((logits + gumbel_noise) / temperature) + + def forward( + self, + gates: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements TopKGating on logits.""" + batch_size, seq_len, d_model = gates.shape + gates_ori = gates + gates = gates.reshape([-1, d_model]) + + # 将 hidden_state 转换成 score(每个 token 对每个专家的偏好分数) + with paddle.amp.auto_cast(False): + hidden_states = gates.cast(self.weight.dtype) + logits = F.linear(hidden_states.cast("float32"), self.weight.cast("float32")) + gates = self.gate_score_func(logits=logits) + + # get topk gates + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.num_experts_per_tok) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.num_experts_per_tok, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, k=self.num_experts_per_tok, n_group=self.n_group, topk_group=self.topk_group + ) + # norm gate to sum 1 + if self.num_experts_per_tok > 1 and self.norm_topk_prob: + denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + top_gate = top_gate / denominator + top_gate = top_gate * self.routed_scaling_factor + + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=1) + + # 计算所有损失函数 + self.current_losses = {} + config_dict = {config.name: config for config in self.loss_configs} + + for config in self.loss_configs: + if not config.enabled: + continue + + loss_func = self.loss_registry.get_loss(config.name) + if loss_func is None: + logger.warning(f"未找到损失函数: {config.name}") + continue + + try: + self.loss_value = loss_func( + routing_weights=top_gate, + selected_experts=top_idx, + gate_logits=gates, + 
+
+    def forward(
+        self,
+        gates: paddle.Tensor,
+    ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Implements top-k gating on logits."""
+        batch_size, seq_len, d_model = gates.shape
+        gates_ori = gates
+        gates = gates.reshape([-1, d_model])
+
+        # Project hidden states to per-expert preference scores for every token.
+        with paddle.amp.auto_cast(False):
+            hidden_states = gates.cast(self.weight.dtype)
+            logits = F.linear(hidden_states.cast("float32"), self.weight.cast("float32"))
+            gates = self.gate_score_func(logits=logits)
+
+        # get topk gates
+        if self.topk_method == "greedy":
+            top_gate, top_idx = self._topk_greedy(gates, k=self.num_experts_per_tok)
+        elif self.topk_method == "group_limited_greedy":
+            top_gate, top_idx = self._topk_group_limited_greedy(
+                gates, k=self.num_experts_per_tok, n_group=self.n_group, topk_group=self.topk_group
+            )
+        elif self.topk_method == "noaux_tc":
+            top_gate, top_idx = self._topk_noaux_tc(
+                gates, k=self.num_experts_per_tok, n_group=self.n_group, topk_group=self.topk_group
+            )
+        else:
+            raise ValueError(f"Invalid topk_method: {self.topk_method}")
+
+        # normalize the top-k gate probabilities so they sum to 1
+        if self.num_experts_per_tok > 1 and self.norm_topk_prob:
+            denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20
+            top_gate = top_gate / denominator
+        top_gate = top_gate * self.routed_scaling_factor
+
+        # get topk mask
+        mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=1)
+
+        # Evaluate every enabled loss function.
+        self.current_losses = {}
+        config_dict = {config.name: config for config in self.loss_configs}
+
+        for config in self.loss_configs:
+            if not config.enabled:
+                continue
+
+            loss_func = self.loss_registry.get_loss(config.name)
+            if loss_func is None:
+                logger.warning(f"Loss function not found: {config.name}")
+                continue
+
+            try:
+                loss_value = loss_func(
+                    routing_weights=top_gate,
+                    selected_experts=top_idx,
+                    gate_logits=gates,
+                    num_experts=self.num_experts,
+                    batch_size=batch_size,
+                    seq_len=seq_len,
+                    **config.params,
+                )
+                self.current_losses[config.name] = loss_value
+            except Exception as e:
+                logger.warning(f"Error while computing loss function {config.name}: {e}")
+                self.current_losses[config.name] = paddle.to_tensor(0.0)
+
+        # Combine the individual losses.
+        self.total_loss = self.loss_combiner(self.current_losses, config_dict)
+
+        exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0)
+
+        if self.drop_tokens:
+            # Compute the configured capacity and drop assignments that fall outside it.
+            capacity = self._capacity(
+                gates,
+                self.capacity_factor * self.num_experts_per_tok,
+                self.max_capacity,
+                self.min_capacity,
+            )
+
+            # update mask and locations by capacity
+            if self.drop_policy == "probs":
+                topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1)
+                capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False)
+                token_priority = self._priority(capacity_indices, capacity)
+            elif self.drop_policy == "position":
+                token_priority = self._priority(top_idx, capacity)
+            else:
+                raise ValueError(f"Invalid drop_policy: {self.drop_policy}")
+        else:
+            # Do not drop tokens: size capacity to the current expert assignments.
+            local_capacity = paddle.max(exp_counts)
+            if self.group is not None:
+                dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group)
+            capacity = int(local_capacity)
+            token_priority = self._priority(top_idx, capacity)
+
+        # normalize gates (gates_masked recovers the same values as top_gate)
+        gates_masked = gates * mask
+        gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
+        denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
+        if self.norm_topk_prob:
+            gates_masked = gates_masked / denom_s
+            gates_masked *= self.routed_scaling_factor
+
+        return (
+            capacity,
+            gates_masked.take_along_axis(top_idx, axis=-1),
+            top_idx,
+            token_priority.take_along_axis(top_idx, axis=-1),
+            self.total_loss,
+            paddle.to_tensor(0.0),
+        )
+
+    def add_loss_config(self, config: LossConfig):
+        """Append a loss configuration."""
+        self.loss_configs.append(config)
+        logger.info(f"Added loss config: {config.name}, weight: {config.weight}")
+
+    def remove_loss_config(self, name: str):
+        """Remove a loss configuration by name."""
+        self.loss_configs = [config for config in self.loss_configs if config.name != name]
+        logger.info(f"Removed loss config: {name}")
+
+    def update_loss_weights(self, weights: Dict[str, float]):
+        """Update loss weights in place."""
+        for config in self.loss_configs:
+            if config.name in weights:
+                config.weight = weights[config.name]
+        logger.info(f"Updated loss weights: {weights}")
+
+    def set_loss_combiner(self, combiner_name: str):
+        """Select the loss combiner by name."""
+        combiner = self.loss_registry.get_combiner(combiner_name)
+        if combiner is not None:
+            self.loss_combiner = combiner
+            logger.info(f"Set loss combiner: {combiner_name}")
+        else:
+            logger.warning(f"Loss combiner not found: {combiner_name}")
+
+    def get_auxiliary_loss(self) -> paddle.Tensor:
+        """Return the auxiliary loss (compatibility helper)."""
+        return self.current_losses.get("auxiliary", paddle.to_tensor(0.0))
+
+    def get_z_loss(self) -> paddle.Tensor:
+        """Return the z-loss (compatibility helper)."""
+        return self.current_losses.get("z_loss", paddle.to_tensor(0.0))
+
+    def get_all_losses(self) -> Dict[str, paddle.Tensor]:
+        """Return a copy of all per-name losses."""
+        return self.current_losses.copy()
+
+    def get_total_loss(self) -> paddle.Tensor:
+        """Return the combined total loss."""
+        return self.total_loss
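A hedged usage sketch of the runtime loss API defined above (the `gate` instance, `step`, and the annealing schedule are assumptions, not part of the diff):

    # Anneal the auxiliary-loss weight over the first 10k steps, then inspect losses.
    gate.update_loss_weights({"auxiliary": 0.01 * max(0.0, 1.0 - step / 10000)})
    gate.set_loss_combiner("adaptive_sum")
    per_name_losses = gate.get_all_losses()  # individual loss tensors by name
    total = gate.get_total_loss()            # scalar produced by the active combiner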
diff --git a/paddleformers/nn/moe_deepep/moe_loss.py b/paddleformers/nn/moe_deepep/moe_loss.py
new file mode 100644
index 0000000000..d8462496f1
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/moe_loss.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, List, Optional, Protocol
+
+import paddle
+
+logger = logging.getLogger(__name__)
+
+
+class LossType(Enum):
+    """Enumeration of loss function types."""
+
+    AUXILIARY = "auxiliary"
+    Z_LOSS = "z_loss"
+    ENTROPY = "entropy"
+    SPARSITY = "sparsity"
+    DIVERSITY = "diversity"
+    CUSTOM = "custom"
+
+
+@dataclass
+class LossConfig:
+    """Configuration for a single loss function."""
+
+    name: str
+    loss_type: LossType
+    weight: float = 0.0
+    enabled: bool = True
+    params: Optional[Dict[str, Any]] = None
+
+    def __post_init__(self):
+        if self.params is None:
+            self.params = {}
+
+
+class LossFunction(Protocol):
+    """Protocol for loss functions."""
+
+    def __call__(
+        self,
+        routing_weights: paddle.Tensor,
+        selected_experts: paddle.Tensor,
+        gate_logits: Optional[paddle.Tensor] = None,
+        **kwargs
+    ) -> paddle.Tensor:
+        """Compute the loss value."""
+        ...
+
+
+class LossCombiner(Protocol):
+    """Protocol for loss combiners."""
+
+    def __call__(self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig]) -> paddle.Tensor:
+        """Combine multiple losses into a single scalar."""
+        ...
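+
+# Any plain function with this signature satisfies LossFunction; a minimal
+# sketch (illustrative only, not registered by default):
+#
+#     def my_balance_loss(routing_weights, selected_experts, gate_logits=None, **kwargs):
+#         # penalize uneven total routing mass across the top-k slots
+#         return paddle.std(routing_weights.sum(axis=0))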
+
+
+class LossRegistry:
+    """Registry for loss functions and loss combiners."""
+
+    def __init__(self):
+        self._loss_functions: Dict[str, LossFunction] = {}
+        self._loss_combiners: Dict[str, LossCombiner] = {}
+        self._register_default_losses()
+        self._register_default_combiners()
+
+    def _register_default_losses(self):
+        """Register the built-in loss functions."""
+        self.register_loss("auxiliary", self._auxiliary_loss)
+        self.register_loss("z_loss", self._z_loss)
+        self.register_loss("entropy", self._entropy_loss)
+        self.register_loss("sparsity", self._sparsity_loss)
+        self.register_loss("diversity", self._diversity_loss)
+
+    def _register_default_combiners(self):
+        """Register the built-in loss combiners."""
+        self.register_combiner("weighted_sum", self._weighted_sum_combiner)
+        self.register_combiner("adaptive_sum", self._adaptive_sum_combiner)
+        self.register_combiner("geometric_mean", self._geometric_mean_combiner)
+
+    def register_loss(self, name: str, loss_func: LossFunction):
+        """Register a loss function under `name`."""
+        self._loss_functions[name] = loss_func
+        logger.info(f"Registered loss function: {name}")
+
+    def register_combiner(self, name: str, combiner: LossCombiner):
+        """Register a loss combiner under `name`."""
+        self._loss_combiners[name] = combiner
+        logger.info(f"Registered loss combiner: {name}")
+
+    def get_loss(self, name: str) -> Optional[LossFunction]:
+        """Look up a loss function by name."""
+        return self._loss_functions.get(name)
+
+    def get_combiner(self, name: str) -> Optional[LossCombiner]:
+        """Look up a loss combiner by name."""
+        return self._loss_combiners.get(name)
+
+    def list_losses(self) -> List[str]:
+        """List the names of all registered loss functions."""
+        return list(self._loss_functions.keys())
+
+    def list_combiners(self) -> List[str]:
+        """List the names of all registered loss combiners."""
+        return list(self._loss_combiners.keys())
+
+    # Built-in loss implementations.
+    def _auxiliary_loss(
+        self,
+        routing_weights: paddle.Tensor,
+        selected_experts: paddle.Tensor,
+        gate_logits: Optional[paddle.Tensor] = None,
+        **kwargs
+    ) -> paddle.Tensor:
+        """Standard auxiliary (load-balancing) loss."""
+        num_experts = kwargs.get("num_experts", selected_experts.max().item() + 1)
+        expert_usage = paddle.zeros([num_experts], dtype=routing_weights.dtype)
+
+        # Accumulate per-expert routing mass, token by token.
+        for i in range(selected_experts.shape[0]):
+            for j in range(selected_experts.shape[1]):
+                expert_idx = selected_experts[i, j].item()
+                expert_usage[expert_idx] += routing_weights[i, j]
+
+        expert_usage = expert_usage / selected_experts.shape[0]
+        aux_loss = paddle.sum(expert_usage * paddle.log(expert_usage + 1e-8))
+        return aux_loss
+
+    def _z_loss(
+        self,
+        routing_weights: paddle.Tensor,
+        selected_experts: paddle.Tensor,
+        gate_logits: Optional[paddle.Tensor] = None,
+        **kwargs
+    ) -> paddle.Tensor:
+        """Standard z-loss, penalizing large gate logits."""
+        if gate_logits is None:
+            return paddle.to_tensor(0.0)
+        return paddle.sum(gate_logits**2)
+
+    def _entropy_loss(
+        self,
+        routing_weights: paddle.Tensor,
+        selected_experts: paddle.Tensor,
+        gate_logits: Optional[paddle.Tensor] = None,
+        **kwargs
+    ) -> paddle.Tensor:
+        """Entropy loss: encourages diversity in the routing weights."""
+        return -paddle.sum(routing_weights * paddle.log(routing_weights + 1e-8))
+
+    def _sparsity_loss(
+        self,
+        routing_weights: paddle.Tensor,
+        selected_experts: paddle.Tensor,
+        gate_logits: Optional[paddle.Tensor] = None,
+        **kwargs
+    ) -> paddle.Tensor:
+        """Sparsity loss: encourages sparse expert selection."""
+        num_experts = kwargs.get("num_experts", selected_experts.max().item() + 1)
+        expert_usage = paddle.zeros([num_experts])
+
+        for i in range(selected_experts.shape[0]):
+            for j in range(selected_experts.shape[1]):
+                expert_idx = selected_experts[i, j].item()
+                expert_usage[expert_idx] += 1
+
+        return paddle.sum(paddle.abs(expert_usage))
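+
+    # Note: the Python loops in the losses above incur O(tokens * k) host round
+    # trips; an equivalent vectorized form (sketch, assuming a 2-D
+    # selected_experts of shape [tokens, k]) would be:
+    #     flat = selected_experts.flatten()
+    #     counts = paddle.bincount(flat, minlength=num_experts)
+    #     usage = paddle.bincount(flat, weights=routing_weights.flatten(),
+    #                             minlength=num_experts) / selected_experts.shape[0]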
+
+    def _diversity_loss(
+        self,
+        routing_weights: paddle.Tensor,
+        selected_experts: paddle.Tensor,
+        gate_logits: Optional[paddle.Tensor] = None,
+        **kwargs
+    ) -> paddle.Tensor:
+        """Diversity loss: encourages a uniform spread of expert selections."""
+        num_experts = kwargs.get("num_experts", selected_experts.max().item() + 1)
+        expert_counts = paddle.zeros([num_experts])
+
+        for i in range(selected_experts.shape[0]):
+            for j in range(selected_experts.shape[1]):
+                expert_idx = selected_experts[i, j].item()
+                expert_counts[expert_idx] += 1
+
+        uniform_dist = paddle.ones_like(expert_counts) / expert_counts.shape[0]
+        diversity_loss = paddle.nn.functional.kl_div(
+            paddle.log(expert_counts + 1e-8), paddle.log(uniform_dist + 1e-8), reduction="sum"
+        )
+        return diversity_loss
+
+    # Built-in combiner implementations.
+    def _weighted_sum_combiner(
+        self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig]
+    ) -> paddle.Tensor:
+        """Weighted sum of all enabled losses."""
+        combined_loss = paddle.to_tensor(0.0)
+        for name, loss_value in losses.items():
+            config = configs.get(name)
+            if config and config.enabled:
+                combined_loss += config.weight * loss_value
+        return combined_loss
+
+    def _adaptive_sum_combiner(
+        self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig]
+    ) -> paddle.Tensor:
+        """Adaptively weighted sum: weights scale with the spread of the losses."""
+        combined_loss = paddle.to_tensor(0.0)
+        enabled_losses = [
+            loss for name, loss in losses.items() if configs.get(name, LossConfig("", LossType.CUSTOM)).enabled
+        ]
+
+        if len(enabled_losses) > 1:
+            loss_std = paddle.std(paddle.stack(enabled_losses))
+        else:
+            loss_std = paddle.to_tensor(1.0)
+
+        adaptation_factor = 0.1
+        for name, loss_value in losses.items():
+            config = configs.get(name)
+            if config and config.enabled:
+                adaptive_weight = config.weight * (1 + adaptation_factor * loss_std)
+                combined_loss += adaptive_weight * loss_value
+
+        return combined_loss
+
+    def _geometric_mean_combiner(
+        self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig]
+    ) -> paddle.Tensor:
+        """Weighted geometric combination of the enabled losses."""
+        combined_loss = paddle.to_tensor(1.0)
+        for name, loss_value in losses.items():
+            config = configs.get(name)
+            if config and config.enabled and config.weight > 0:
+                combined_loss *= (loss_value + 1e-8) ** config.weight
+        return combined_loss
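Extending the registry does not require modifying moe_loss.py; a registration sketch using only the APIs above (the loss body itself is an assumption):

    import paddle

    from paddleformers.nn.moe_deepep.moe_loss import LossRegistry

    registry = LossRegistry()

    def token_entropy_loss(routing_weights, selected_experts, gate_logits=None, **kwargs):
        # Mean negative entropy of each token's top-k routing distribution.
        return paddle.mean(paddle.sum(routing_weights * paddle.log(routing_weights + 1e-8), axis=-1))

    registry.register_loss("token_entropy", token_entropy_loss)
    print(registry.list_losses())  # built-ins plus "token_entropy"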
diff --git a/paddleformers/nn/moe_deepep/moe_loss_instance.py b/paddleformers/nn/moe_deepep/moe_loss_instance.py
new file mode 100644
index 0000000000..3d045ab8ea
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/moe_loss_instance.py
@@ -0,0 +1,57 @@
+from typing import Dict, Optional
+
+import paddle
+
+from .moe_loss import LossConfig, LossRegistry
+
+
+# Global loss registry instance, created lazily through a function.
+def get_global_loss_registry():
+    if not hasattr(get_global_loss_registry, "_instance"):
+        get_global_loss_registry._instance = LossRegistry()
+        # Register custom loss functions with the global registry.
+        get_global_loss_registry._instance.register_loss("custom_diversity_loss1", custom_diversity_loss)
+        # Register custom combiner methods with the global registry.
+        get_global_loss_registry._instance.register_combiner(
+            "custom_weighted_sum_combiner1", custom_weighted_sum_combiner
+        )
+    return get_global_loss_registry._instance
+
+
+def custom_diversity_loss(
+    routing_weights: paddle.Tensor,
+    selected_experts: paddle.Tensor,
+    gate_logits: Optional[paddle.Tensor] = None,
+    **kwargs
+) -> paddle.Tensor:
+    """Custom diversity loss."""
+    num_experts = kwargs.get("num_experts", 8)
+    expert_counts = paddle.zeros([num_experts])
+
+    for i in range(selected_experts.shape[0]):
+        for j in range(selected_experts.shape[1]):
+            expert_idx = selected_experts[i, j].item()
+            expert_counts[expert_idx] += 1
+
+    uniform_dist = paddle.ones_like(expert_counts) / expert_counts.shape[0]
+    expert_probs = expert_counts / (expert_counts.sum() + 1e-8)
+
+    diversity_loss = paddle.nn.functional.kl_div(
+        paddle.log(expert_probs + 1e-8),
+        paddle.log(uniform_dist + 1e-8),
+        reduction="sum",
+    )
+
+    return diversity_loss
+
+
+# Note: this is a module-level function registered as a combiner, so it takes
+# no `self` parameter; it is called as combiner(losses, configs).
+def custom_weighted_sum_combiner(
+    losses: Dict[str, paddle.Tensor],
+    configs: Dict[str, LossConfig],
+) -> paddle.Tensor:
+    """Weighted-sum combiner."""
+    combined_loss = paddle.to_tensor(0.0)
+    for name, loss_value in losses.items():
+        config = configs.get(name)
+        if config and config.enabled:
+            combined_loss += config.weight * loss_value
+    return combined_loss
diff --git a/paddleformers/transformers/qwen3_moe/modeling.py b/paddleformers/transformers/qwen3_moe/modeling.py
index 09d08f9d5d..3dc6249a6a 100644
--- a/paddleformers/transformers/qwen3_moe/modeling.py
+++ b/paddleformers/transformers/qwen3_moe/modeling.py
@@ -23,13 +23,18 @@
 from paddle import Tensor, nn
-from paddle.distributed.fleet.utils import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
-
+from paddle.distributed.fleet.meta_parallel import (
+    LayerDesc,
+    PipelineLayer,
+    SharedLayerDesc,
+)
+from paddle.distributed.fleet.recompute.recompute import recompute
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
 from ...nn.embedding import Embedding as GeneralEmbedding
 from ...nn.linear import Linear as GeneralLinear
 from ...nn.lm_head import LMHead as GeneralLMHead
 from ...nn.mlp import MLP
+from ...nn.moe_deepep.moe_factory import QuickAccessMoEFactory
 from ...nn.norm import Norm as GeneralNorm
 from ...nn.pp_model import GeneralModelForCausalLMPipe
 from ...utils.log import logger
@@ -347,7 +352,7 @@
     def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
         self.self_attn = Qwen3MoeAttention(config, layer_idx)
 
         if config.num_experts > 0:
-            self.mlp = Qwen3MoeSparseMoeBlock(config)
+            self.mlp = QuickAccessMoEFactory.create_from_model_name(config)
         else:  # num_experts == 0 or this layer is not sparse layer
             self.mlp = Qwen3MoeMLP(config)
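An end-to-end wiring sketch for the pieces above (illustrative; only the registry and config APIs shown in this diff are assumed):

    from paddleformers.nn.moe_deepep.moe_loss import LossConfig, LossType
    from paddleformers.nn.moe_deepep.moe_loss_instance import get_global_loss_registry

    registry = get_global_loss_registry()  # lazy singleton carrying the custom entries
    loss_configs = [
        LossConfig("auxiliary", LossType.AUXILIARY, weight=0.01),
        LossConfig("custom_diversity_loss1", LossType.CUSTOM, weight=0.001),
    ]
    # `registry` and `loss_configs` are then handed to FlexibleMoEGate as its
    # `loss_registry` and `loss_configs` arguments.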