From ee8d2d71ad9f317bab2866d5e41890c6af6863ed Mon Sep 17 00:00:00 2001
From: lizhenxing02
Date: Wed, 27 Aug 2025 16:37:17 +0800
Subject: [PATCH 1/3] adapt aux free

---
 .../models/ernie/modeling_auto.py             | 10 ++-
 .../pre-training/models/moe/moe_layer_auto.py | 81 ++++++++++++++++++-
 2 files changed, 87 insertions(+), 4 deletions(-)

diff --git a/examples/pre-training/models/ernie/modeling_auto.py b/examples/pre-training/models/ernie/modeling_auto.py
index cef7c4ccb..32588c27a 100644
--- a/examples/pre-training/models/ernie/modeling_auto.py
+++ b/examples/pre-training/models/ernie/modeling_auto.py
@@ -49,6 +49,7 @@
 from models.moe.moe_layer_auto import (
     MOELayerAuto,
+    MoEStatics,
 )
 from models.ernie.configuration_auto import ErnieMoEConfig
 from models.moe.moe_utils_auto import get_mesh
@@ -603,7 +604,11 @@ def get_gate(
         )
         experts[i].ep_group_id = ep_group_id

-    return gate, experts, lm_gate, lm_experts
+    if config.moe_use_aux_free:
+        moe_statics = MoEStatics(config, layer_idx)
+    else:
+        moe_statics = None
+    return gate, experts, lm_gate, lm_experts, moe_statics


 def _parse_moe_group(moe_group: str):
@@ -1433,7 +1438,7 @@ def create_moe_mlp_layer(self, layer_idx, ipp):
             fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))]
         else:
             fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))]
-        gate, experts, lm_gate, lm_experts = get_gate(
+        gate, experts, lm_gate, lm_experts, moe_statics = get_gate(
             self.config, fc, layer_idx, self.ipp
         )
         _sh_cfg = deepcopy(self.config)
@@ -1476,6 +1481,7 @@ def create_moe_mlp_layer(self, layer_idx, ipp):
                 enable_pbr=self.config.moe_use_bpr,
                 all_to_all_dropout=self.config.moe_all_to_all_dropout,
                 group_experts=self.config.moe_group_experts,
+                moe_statics=moe_statics,
                 config=self.config,
                 ipp=self.ipp,
             )
diff --git a/examples/pre-training/models/moe/moe_layer_auto.py b/examples/pre-training/models/moe/moe_layer_auto.py
index b267cce78..e1ae19661 100644
--- a/examples/pre-training/models/moe/moe_layer_auto.py
+++ b/examples/pre-training/models/moe/moe_layer_auto.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import inspect
 from typing import Tuple, List, Optional
 import logging
 from contextlib import contextmanager
@@ -62,6 +62,56 @@
 )


+class MoEStatics(nn.Layer):
+    """
+    Stores MoE (Mixture of Experts) statistics
+    and expert usage information.
+    """
+
+    def __init__(self, config, layer_idx):
+        """
+        Initialize MoE statistics tracking.
+
+        Args:
+            config: Model configuration containing MoE parameters
+            layer_idx: Index of the MoE layer in the model
+        """
+        super().__init__()
+        self._cast_to_low_precision = False
+        self._cast_to_low_precison = False
+        num_experts = (
+            config.moe_num_experts[0]
+            if config.multimodel_experts
+            else config.moe_num_experts
+        )
+        if config.multimodel_experts:
+            assert (
+                len(set(config.moe_num_experts)) == 1
+            ), f"assume expert group has same size, got: {config.moe_num_experts}"
+
+        with paddle.utils.unique_name.guard(f"mm_layer_{layer_idx}_"):
+            num_experts_groups = (
+                len(config.moe_num_experts) if config.multimodel_experts else 1
+            )
+            p = self.create_parameter(
+                shape=[num_experts_groups, num_experts],
+                dtype="float32",
+                is_bias=True,
+                attr=paddle.ParamAttr(
+                    name=paddle.utils.unique_name.generate("corr_bias")
+                ),
+            )
+            p.stop_gradient = True
+            self.e_score_correction_bias = p
+            self.e_score_correction_bias.is_distributed = True
+            p = paddle.zeros(
+                shape=[num_experts_groups, num_experts],
+                dtype="int64",
+            )
+            p.stop_gradient = True
+            self.expert_usage = p
+
+
 @contextmanager
 def profile(name):
     """doc"""
@@ -518,10 +568,13 @@ def __init__(
         enable_pbr: bool = False,
         all_to_all_dropout=0,
         group_experts=False,
+        moe_statics=None,
         config=None,
         ipp=0,
     ):
         nn.Layer.__init__(self)
+        self.use_correction_bias = moe_statics is not None
+        self.moe_statics = moe_statics
         self.config = config
         self.gate = gate
         self.layer_idx = layer_idx
@@ -677,6 +730,21 @@ def gate_and_distpach(self, input, token_type_ids):
         prob, max_prob = self.fused_gate_logits_process(
             gate_logits, token_type_ids
         )
+        if (
+            "corr_bias"
+            in inspect.signature(
+                paddle.incubate.nn.functional.moe_gate_dispatch
+            ).parameters
+        ):
+            if self.use_correction_bias:
+                compat_args = (self.moe_statics.e_score_correction_bias[0],)
+            else:
+                compat_args = (None,)
+        else:
+            assert (
+                not self.use_correction_bias
+            ), "correction bias not supported, rebuild moe-ops"
+            compat_args = ()
         (
             dispatched_input,
             combine_weights_unnorm,
@@ -684,8 +752,17 @@
             expert_gate,
             dispatch_mask,
             _,
         ) = paddle.incubate.nn.functional.moe_gate_dispatch(
-            input, prob, None, k, local_capacity, True
+            input, prob, *compat_args, k, local_capacity, True
         )
+        dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0)))
+        if self.use_correction_bias:
+            if self.gate.config.multimodel_experts:
+                for i in range(len(self.moe_statics.expert_usage)):
+                    self.moe_statics.expert_usage[i] += dispatch_mask[
+                        self.gate.experts_type_mask[i]
+                    ].detach()
+            else:
+                self.moe_statics.expert_usage[0] += dispatch_mask.detach()
         dispatched_input.stop_gradient = False
         combine_weights_unnorm.stop_gradient = False
         dispatch_mask.stop_gradient = True

From 917f3cb78ac1379168c72c7e8afc0e1a217dc0d2 Mon Sep 17 00:00:00 2001
From: lizhenxing02
Date: Thu, 28 Aug 2025 15:26:45 +0800
Subject: [PATCH 2/3] aux free callback

---
 examples/pre-training/ernie/pretrain_auto.py  | 12 ++-
 .../ernie/src/callbacks_auto/__init__.py      |  2 +
 .../moe_correction_bias_adjust_callback.py    | 81 +++++++++++++++++++
 .../src/trainers/pretraining_trainer_auto.py  |  6 +-
 4 files changed, 97 insertions(+), 4 deletions(-)
 create mode 100644 examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py

diff --git a/examples/pre-training/ernie/pretrain_auto.py b/examples/pre-training/ernie/pretrain_auto.py
index 90fc656f3..a2e584730 100644
--- a/examples/pre-training/ernie/pretrain_auto.py
+++ b/examples/pre-training/ernie/pretrain_auto.py
@@ -37,7 +37,10 @@
     ErnieMoEConfig,
 )
-from src.callbacks import GlobalRNGCallback
+from src.callbacks_auto import (
+    GlobalRNGCallback,
+    MoECorrectionBiasAdjustCallback,
+)
 from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer
 from src.trainers import AutoPretrainingTrainer, AutoPreTrainingArguments
 from src.utils_auto import setup_logger_output_file, logger
@@ -539,6 +542,13 @@ def main():
     # 6. prepare for train/eval
     callbacks = [GlobalRNGCallback()]
+    if getattr(cfg, "moe_use_aux_free", 0.0) > 0.0:
+        logger.info("adding aux free callback")
+        callbacks += [
+            MoECorrectionBiasAdjustCallback(
+                args.moe_use_aux_free_update_coef, args.sequence_parallel
+            )
+        ]

     init_parameters(model)
     trainer = AutoPretrainingTrainer(
diff --git a/examples/pre-training/ernie/src/callbacks_auto/__init__.py b/examples/pre-training/ernie/src/callbacks_auto/__init__.py
index 26eac2da4..8fcd377dd 100644
--- a/examples/pre-training/ernie/src/callbacks_auto/__init__.py
+++ b/examples/pre-training/ernie/src/callbacks_auto/__init__.py
@@ -16,10 +16,12 @@
 from .stopper_callback import StopperCallback
 from .moe_logging_callback import GlobalRNGCallback
 from .tensorboard_callback import TensorBoardCallback
+from .moe_correction_bias_adjust_callback import MoECorrectionBiasAdjustCallback

 __all__ = [
     "TensorBoardCallback",
     "LoggingCallback",
     "GlobalRNGCallback",
     "StopperCallback",
+    "MoECorrectionBiasAdjustCallback",
 ]
diff --git a/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py b/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py
new file mode 100644
index 000000000..f986ba18b
--- /dev/null
+++ b/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle
+import paddle.distributed as dist
+from models.ernie.modeling_auto import ErnieDecoderLayerAuto
+from models.moe.moe_layer_auto import MOELayerAuto
+from paddle.distributed.fleet import fleet
+from paddleformers.trainer.trainer_callback import TrainerCallback
+
+
+class MoECorrectionBiasAdjustCallback(TrainerCallback):
+    def __init__(self, lr, use_sp):
+        super().__init__()
+        self.update_lr = float(lr)
+        self.use_sp = use_sp
+
+    def on_optimizer_end(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+
+        usages = {}
+        biases = {}
+
+        def get_stat(layer):
+            nonlocal usages, biases
+            if isinstance(layer, ErnieDecoderLayerAuto):
+                if not isinstance(layer.mlp, (MOELayerAuto)):
+                    return
+                assert hasattr(
+                    layer.mlp, "moe_statics"
+                ), "make sure update to latest ernie-core, to use AuxFree Balance"
+                usages[layer.layer_idx] = layer.mlp.moe_statics.expert_usage
+                biases[layer.layer_idx] = layer.mlp.moe_statics.e_score_correction_bias
+
+        model.apply(get_stat)
+        if not usages:
+            return
+        keys, tensor_list = zip(*sorted(usages.items(), key=lambda x: x[0]))
+        usages_tensor = paddle.stack(tensor_list, 0)
+        if not hasattr(fleet, "_hcg"):
+            dist.all_reduce(usages_tensor)
+            return
+
+        # hcg = fleet.get_hybrid_communicate_group()
+        # mp_group = hcg.get_model_parallel_group()
+        # dp_group = hcg.get_data_parallel_group()
+        # sd_group = hcg.get_sharding_parallel_group()
+        # if self.use_sp and mp_group.nranks > 1:
+        #     dist.all_reduce(usages_tensor, group=mp_group)
+        # if dp_group.nranks > 1:
+        #     dist.all_reduce(usages_tensor, group=dp_group)
+        # if sd_group.nranks > 1:
+        #     dist.all_reduce(usages_tensor, group=sd_group)
+        usages_mean = usages_tensor.mean(-1, keepdim=True)
+        update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
+        update_dict = dict(zip(keys, update))
+
+        def update_bias(layer):
+            nonlocal usages, biases
+            if isinstance(layer, ErnieDecoderLayerAuto):
+                if not isinstance(layer.mlp, MOELayerAuto):
+                    return
+                with paddle.no_grad():
+                    if layer.mlp.gate.weight.stop_gradient:
+                        update_dict[layer.layer_idx][0, :] = 0
+                    biases[layer.layer_idx].add_(update_dict[layer.layer_idx])
+                    usages[layer.layer_idx].data.zero_()
+
+        model.apply(update_bias)
diff --git a/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py b/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py
index 7c2ba0c73..b0d2c6f0f 100644
--- a/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py
+++ b/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py
@@ -225,9 +225,9 @@ class AutoPreTrainingArguments(AutoTrainingArguments):
         default="ernie",
         metadata={"help": "Only support for ernie pre-training for now."},
     )
-    n_microbatches: int = field(
-        default=1,
-        metadata={"help": "Control the num of microbatches in one pp step."},
+    moe_use_aux_free_update_coef: float = field(
+        default=1.0e-3,
+        metadata={"help": "moe aux free update coef"},
     )

     @property

From c06c0c9cd87c2c759a05032819aa1d9656144ac6 Mon Sep 17 00:00:00 2001
From: ZhenxingLi
Date: Thu, 28 Aug 2025 16:10:27 +0800
Subject: [PATCH 3/3] fix allreduce

---
 .../moe_correction_bias_adjust_callback.py    | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py b/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py
index f986ba18b..42aecb780 100644
--- a/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py
+++ b/examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py
@@ -53,16 +53,16 @@ def get_stat(layer):
             dist.all_reduce(usages_tensor)
             return

-        # hcg = fleet.get_hybrid_communicate_group()
-        # mp_group = hcg.get_model_parallel_group()
-        # dp_group = hcg.get_data_parallel_group()
-        # sd_group = hcg.get_sharding_parallel_group()
-        # if self.use_sp and mp_group.nranks > 1:
-        #     dist.all_reduce(usages_tensor, group=mp_group)
-        # if dp_group.nranks > 1:
-        #     dist.all_reduce(usages_tensor, group=dp_group)
-        # if sd_group.nranks > 1:
-        #     dist.all_reduce(usages_tensor, group=sd_group)
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dp_group = hcg.get_data_parallel_group()
+        sd_group = hcg.get_sharding_parallel_group()
+        if self.use_sp and mp_group.nranks > 1:
+            dist.all_reduce(usages_tensor._local_value(), group=mp_group)
+        if dp_group.nranks > 1:
+            dist.all_reduce(usages_tensor._local_value(), group=dp_group)
+        if sd_group.nranks > 1:
+            dist.all_reduce(usages_tensor._local_value(), group=sd_group)
         usages_mean = usages_tensor.mean(-1, keepdim=True)
         update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
         update_dict = dict(zip(keys, update))
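Note on the routing change in PATCH 1: when the installed moe_gate_dispatch kernel accepts corr_bias, the layer's e_score_correction_bias row is passed in so that expert selection is made on bias-corrected scores while the combine weights still come from the raw gate probabilities. Below is a minimal, framework-agnostic sketch of that idea in plain numpy; it is an illustration only, biased_topk_route is a hypothetical name, and it is not the fused Paddle kernel used by the patch.

import numpy as np


def biased_topk_route(gate_logits, correction_bias, k=2):
    # Softmax over experts gives the raw routing probabilities.
    probs = np.exp(gate_logits - gate_logits.max(-1, keepdims=True))
    probs = probs / probs.sum(-1, keepdims=True)
    # The correction bias only influences which experts are selected ...
    biased_scores = probs + correction_bias
    topk_idx = np.argsort(-biased_scores, axis=-1)[:, :k]
    # ... while the combine weights still come from the unbiased probabilities.
    combine_weights = np.take_along_axis(probs, topk_idx, axis=-1)
    return topk_idx, combine_weights


rng = np.random.default_rng(0)
gate_logits = rng.normal(size=(4, 8)).astype("float32")  # 4 tokens, 8 experts
correction_bias = np.zeros(8, dtype="float32")
idx, weights = biased_topk_route(gate_logits, correction_bias, k=2)
print(idx.shape, weights.shape)  # (4, 2) (4, 2)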
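Note on PATCH 2/3: MoECorrectionBiasAdjustCallback implements the aux-free balancing rule. After every optimizer step it gathers per-expert usage counts, sums them across the parallel groups, then moves each expert's correction bias by a fixed step (moe_use_aux_free_update_coef) toward even usage and clears the counters. A minimal single-process sketch of that update rule follows; names are hypothetical and the distributed all_reduce is omitted.

import numpy as np


def adjust_correction_bias(bias, usage, coef=1e-3):
    # Under-used experts get +coef, over-used experts get -coef, balanced ones 0.
    mean_usage = usage.astype("float64").mean(-1, keepdims=True)
    update = np.sign(mean_usage - usage) * coef
    return bias + update, np.zeros_like(usage)  # usage counters are reset each step


bias = np.zeros(8)
usage = np.array([10, 0, 5, 5, 5, 5, 5, 5], dtype="int64")  # one hot, one cold expert
bias, usage = adjust_correction_bias(bias, usage)
print(bias)  # expert 0 (over-used) gets -1e-3, expert 1 (under-used) gets +1e-3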