diff --git a/examples/pre-training/ernie/pretrain_auto.py b/examples/pre-training/ernie/pretrain_auto.py
index 5d0e7e64a..ee068af95 100644
--- a/examples/pre-training/ernie/pretrain_auto.py
+++ b/examples/pre-training/ernie/pretrain_auto.py
@@ -37,7 +37,10 @@
     ErnieMoEConfig,
 )
 
-from src.callbacks import GlobalRNGCallback
+from src.callbacks_auto import (
+    GlobalRNGCallback,
+    MoECorrectionBiasAdjustCallback,
+)
 from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer
 from src.trainers import AutoPretrainingTrainer, AutoPreTrainingArguments
 from src.utils_auto import setup_logger_output_file, logger
@@ -524,6 +527,13 @@ def main():
 
     # 6. prepare for train/eval
     callbacks = [GlobalRNGCallback()]
+    if getattr(cfg, "moe_use_aux_free", 0.0) > 0.0:
+        logger.info("adding aux free callback")
+        callbacks += [
+            MoECorrectionBiasAdjustCallback(
+                args.moe_use_aux_free_update_coef, args.sequence_parallel
+            )
+        ]
 
     init_parameters(model)
     trainer = AutoPretrainingTrainer(
diff --git a/examples/pre-training/ernie/src/callbacks_auto/__init__.py b/examples/pre-training/ernie/src/callbacks_auto/__init__.py
index 26eac2da4..8fcd377dd 100644
--- a/examples/pre-training/ernie/src/callbacks_auto/__init__.py
+++ b/examples/pre-training/ernie/src/callbacks_auto/__init__.py
@@ -16,10 +16,12 @@
 from .stopper_callback import StopperCallback
 from .moe_logging_callback import GlobalRNGCallback
 from .tensorboard_callback import TensorBoardCallback
+from .moe_correction_bias_adjust_callback import MoECorrectionBiasAdjustCallback
 
 __all__ = [
     "TensorBoardCallback",
     "LoggingCallback",
     "GlobalRNGCallback",
     "StopperCallback",
+    "MoECorrectionBiasAdjustCallback",
 ]
diff --git a/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py b/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py
new file mode 100644
index 000000000..42aecb780
--- /dev/null
+++ b/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle
+import paddle.distributed as dist
+from models.ernie.modeling_auto import ErnieDecoderLayerAuto
+from models.moe.moe_layer_auto import MOELayerAuto
+from paddle.distributed.fleet import fleet
+from paddleformers.trainer.trainer_callback import TrainerCallback
+
+
+class MoECorrectionBiasAdjustCallback(TrainerCallback):
+    def __init__(self, lr, use_sp):
+        super().__init__()
+        self.update_lr = float(lr)
+        self.use_sp = use_sp
+
+    def on_optimizer_end(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+
+        usages = {}
+        biases = {}
+
+        def get_stat(layer):
+            nonlocal usages, biases
+            if isinstance(layer, ErnieDecoderLayerAuto):
+                if not isinstance(layer.mlp, (MOELayerAuto)):
+                    return
+                assert hasattr(
+                    layer.mlp, "moe_statics"
+                ), "make sure you have updated to the latest ernie-core to use AuxFree Balance"
+                usages[layer.layer_idx] = layer.mlp.moe_statics.expert_usage
+                biases[layer.layer_idx] = layer.mlp.moe_statics.e_score_correction_bias
+
+        model.apply(get_stat)
+        if not usages:
+            return
+        keys, tensor_list = zip(*sorted(usages.items(), key=lambda x: x[0]))
+        usages_tensor = paddle.stack(tensor_list, 0)
+        if not hasattr(fleet, "_hcg"):
+            dist.all_reduce(usages_tensor)
+            return
+
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dp_group = hcg.get_data_parallel_group()
+        sd_group = hcg.get_sharding_parallel_group()
+        if self.use_sp and mp_group.nranks > 1:
+            dist.all_reduce(usages_tensor._local_value(), group=mp_group)
+        if dp_group.nranks > 1:
+            dist.all_reduce(usages_tensor._local_value(), group=dp_group)
+        if sd_group.nranks > 1:
+            dist.all_reduce(usages_tensor._local_value(), group=sd_group)
+        usages_mean = usages_tensor.mean(-1, keepdim=True)
+        update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
+        update_dict = dict(zip(keys, update))
+
+        def update_bias(layer):
+            nonlocal usages, biases
+            if isinstance(layer, ErnieDecoderLayerAuto):
+                if not isinstance(layer.mlp, MOELayerAuto):
+                    return
+                with paddle.no_grad():
+                    if layer.mlp.gate.weight.stop_gradient:
+                        update_dict[layer.layer_idx][0, :] = 0
+                    biases[layer.layer_idx].add_(update_dict[layer.layer_idx])
+                    usages[layer.layer_idx].data.zero_()
+
+        model.apply(update_bias)
diff --git a/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py b/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py
index 6aebf26ad..9016e3650 100644
--- a/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py
+++ b/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py
@@ -182,9 +182,9 @@ class AutoPreTrainingArguments(AutoTrainingArguments):
         default="ernie",
         metadata={"help": "Only support for ernie pre-training for now."},
     )
-    n_microbatches: int = field(
-        default=1,
-        metadata={"help": "Control the num of microbatches in one pp step."},
+    moe_use_aux_free_update_coef: float = field(
+        default=1.0e-3,
+        metadata={"help": "Update coefficient for the MoE aux-loss-free correction bias."},
     )
 
     @property
diff --git a/examples/pre-training/models/ernie/modeling_auto.py b/examples/pre-training/models/ernie/modeling_auto.py
index 58b193346..433669060 100644
--- a/examples/pre-training/models/ernie/modeling_auto.py
+++ b/examples/pre-training/models/ernie/modeling_auto.py
@@ -48,6 +48,7 @@
 from models.comm_utils import subbatch
 from models.moe.moe_layer_auto import (
     MOELayerAuto,
+    MoEStatics,
 )
 from models.ernie.configuration_auto import ErnieMoEConfig
 from models.moe.moe_utils_auto import get_mesh
@@ -482,7 +483,11 @@ def get_gate(
         )
         experts[i].ep_group_id = ep_group_id
 
-    return gate, experts, lm_gate, lm_experts
+    if config.moe_use_aux_free:
+        moe_statics = MoEStatics(config, layer_idx)
+    else:
+        moe_statics = None
+    return gate, experts, lm_gate, lm_experts, moe_statics
 
 
 class RMSNorm(nn.Layer):
@@ -1218,7 +1223,7 @@ def create_moe_mlp_layer(self, layer_idx, ipp):
             fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))]
         else:
             fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))]
-        gate, experts, lm_gate, lm_experts = get_gate(
+        gate, experts, lm_gate, lm_experts, moe_statics = get_gate(
             self.config, fc, layer_idx, self.ipp
         )
         _sh_cfg = deepcopy(self.config)
@@ -1259,6 +1264,7 @@ def create_moe_mlp_layer(self, layer_idx, ipp):
             k=self.config.moe_k,
             all_to_all_dropout=self.config.moe_all_to_all_dropout,
             group_experts=self.config.moe_group_experts,
+            moe_statics=moe_statics,
             config=self.config,
             ipp=self.ipp,
         )
diff --git a/examples/pre-training/models/moe/moe_layer_auto.py b/examples/pre-training/models/moe/moe_layer_auto.py
index 9edfcdf52..7f07aa695 100644
--- a/examples/pre-training/models/moe/moe_layer_auto.py
+++ b/examples/pre-training/models/moe/moe_layer_auto.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 from typing import Tuple, List, Optional
 import logging
 from contextlib import contextmanager
@@ -35,6 +36,56 @@
 logger = logging.getLogger(__name__)
 
 
+class MoEStatics(nn.Layer):
+    """
+    Stores MoE (Mixture of Experts) statistics
+    and expert usage information.
+    """
+
+    def __init__(self, config, layer_idx):
+        """
+        Initialize MoE statistics tracking.
+
+        Args:
+            config: Model configuration containing MoE parameters
+            layer_idx: Index of the MoE layer in the model
+        """
+        super().__init__()
+        self._cast_to_low_precision = False
+        self._cast_to_low_precison = False  # misspelled twin kept for framework compatibility
+        num_experts = (
+            config.moe_num_experts[0]
+            if config.multimodel_experts
+            else config.moe_num_experts
+        )
+        if config.multimodel_experts:
+            assert (
+                len(set(config.moe_num_experts)) == 1
+            ), f"all expert groups are assumed to have the same size, got: {config.moe_num_experts}"
+
+        with paddle.utils.unique_name.guard(f"mm_layer_{layer_idx}_"):
+            num_experts_groups = (
+                len(config.moe_num_experts) if config.multimodel_experts else 1
+            )
+            p = self.create_parameter(
+                shape=[num_experts_groups, num_experts],
+                dtype="float32",
+                is_bias=True,
+                attr=paddle.ParamAttr(
+                    name=paddle.utils.unique_name.generate("corr_bias")
+                ),
+            )
+            p.stop_gradient = True
+            self.e_score_correction_bias = p
+            self.e_score_correction_bias.is_distributed = True
+            p = paddle.zeros(
+                shape=[num_experts_groups, num_experts],
+                dtype="int64",
+            )
+            p.stop_gradient = True
+            self.expert_usage = p
+
+
 @contextmanager
 def profile(name):
     """doc"""
@@ -529,13 +580,34 @@ def gate_and_distpach(self, input, token_type_ids):
             prob, max_prob = self.fused_gate_logits_process(
                 gate_logits, token_type_ids
             )
+            if "corr_bias" in inspect.signature(moe_gate_dispatch).parameters:
+                if self.use_correction_bias:
+                    compat_args = (self.moe_statics.e_score_correction_bias[0],)
+                else:
+                    compat_args = (None,)
+            else:
+                assert (
+                    not self.use_correction_bias
+                ), "correction bias is not supported by this moe_gate_dispatch op, rebuild moe-ops"
+                compat_args = ()
             (
                 dispatched_input,
                 combine_weights_unnorm,
                 scatter_index,
                 dispatch_mask,
                 _,
-            ) = moe_gate_dispatch(input, prob, None, k, local_capacity, True)
+            ) = moe_gate_dispatch(
+                input, prob, *compat_args, k, local_capacity, True
+            )
+            dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0)))
+            if self.use_correction_bias:
+                if self.gate.config.multimodel_experts:
+                    for i in range(len(self.moe_statics.expert_usage)):
+                        self.moe_statics.expert_usage[i] += dispatch_mask[
+                            self.gate.experts_type_mask[i]
+                        ].detach()
+                else:
+                    self.moe_statics.expert_usage[0] += dispatch_mask.detach()
             dispatched_input.stop_gradient = False
             combine_weights_unnorm.stop_gradient = False
             dispatch_mask.stop_gradient = True
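Note for reviewers: the update rule that MoECorrectionBiasAdjustCallback.on_optimizer_end applies can be sketched in isolation as below. This is a minimal single-process Python sketch, not part of the patch: it drops the cross-rank all-reduce of the usage counters and the frozen-gate special case, and the helper name adjust_correction_bias is hypothetical.

import paddle


def adjust_correction_bias(expert_usage, correction_bias, update_coef=1.0e-3):
    """Nudge the gate's e_score_correction_bias toward balanced expert usage.

    expert_usage:    int64 tensor [num_groups, num_experts], tokens routed to
                     each expert since the last adjustment.
    correction_bias: float32 tensor [num_groups, num_experts], added to the
                     gate scores before top-k routing.
    """
    usage = expert_usage.astype("float32")
    mean_usage = usage.mean(axis=-1, keepdim=True)
    # sign(mean - usage): +coef for under-used experts, -coef for over-used ones
    update = paddle.sign(mean_usage - usage) * update_coef
    with paddle.no_grad():
        correction_bias.add_(update)
    # usage counters are reset after every adjustment, as in update_bias above
    expert_usage.zero_()
    return correction_bias


usage = paddle.to_tensor([[10, 2, 1, 3]], dtype="int64")
bias = paddle.zeros([1, 4], dtype="float32")
print(adjust_correction_bias(usage, bias))
# expert 0 (over-used) is pushed down by update_coef, the other experts are pushed up

The callback applies this rule per MoE layer after each optimizer step, but first all-reduces the stacked expert_usage counters across the model-parallel (when sequence parallel is on), data-parallel, and sharding groups so that every rank applies an identical bias update.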