12 changes: 11 additions & 1 deletion examples/pre-training/ernie/pretrain_auto.py
@@ -37,7 +37,10 @@
ErnieMoEConfig,
)

from src.callbacks import GlobalRNGCallback
from src.callbacks_auto import (
GlobalRNGCallback,
MoECorrectionBiasAdjustCallback,
)
from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer
from src.trainers import AutoPretrainingTrainer, AutoPreTrainingArguments
from src.utils_auto import setup_logger_output_file, logger
@@ -524,6 +527,13 @@ def main():

# 6. prepare for train/eval
callbacks = [GlobalRNGCallback()]
if getattr(cfg, "moe_use_aux_free", 0.0) > 0.0:
logger.info("adding aux free callback")
callbacks += [
MoECorrectionBiasAdjustCallback(
args.moe_use_aux_free_update_coef, args.sequence_parallel
)
]
init_parameters(model)

trainer = AutoPretrainingTrainer(
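Note: the callback is constructed with two positional arguments that map onto the lr and use_sp parameters of MoECorrectionBiasAdjustCallback (defined in the new file below), and its work happens in the on_optimizer_end hook, i.e. once per optimizer step. An equivalent keyword-argument form of the wiring, purely illustrative:

    if getattr(cfg, "moe_use_aux_free", 0.0) > 0.0:
        logger.info("adding aux free callback")
        callbacks.append(
            MoECorrectionBiasAdjustCallback(
                lr=args.moe_use_aux_free_update_coef,  # bias-adjust step size, default 1.0e-3
                use_sp=args.sequence_parallel,         # controls the extra model-parallel all-reduce
            )
        )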
2 changes: 2 additions & 0 deletions examples/pre-training/ernie/src/callbacks_auto/__init__.py
@@ -16,10 +16,12 @@
from .stopper_callback import StopperCallback
from .moe_logging_callback import GlobalRNGCallback
from .tensorboard_callback import TensorBoardCallback
from .moe_correction_bias_adjust_callback import MoECorrectionBiasAdjustCallback

__all__ = [
"TensorBoardCallback",
"LoggingCallback",
"GlobalRNGCallback",
"StopperCallback",
"MoECorrectionBiasAdjustCallback",
]
81 changes: 81 additions & 0 deletions examples/pre-training/ernie/src/callbacks_auto/moe_correction_bias_adjust_callback.py
@@ -0,0 +1,81 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle
import paddle.distributed as dist
from models.ernie.modeling_auto import ErnieDecoderLayerAuto
from models.moe.moe_layer_auto import MOELayerAuto
from paddle.distributed.fleet import fleet
from paddleformers.trainer.trainer_callback import TrainerCallback


class MoECorrectionBiasAdjustCallback(TrainerCallback):
def __init__(self, lr, use_sp):
super().__init__()
self.update_lr = float(lr)
self.use_sp = use_sp

def on_optimizer_end(self, args, state, control, **kwargs):
model = kwargs["model"]

usages = {}
biases = {}

def get_stat(layer):
nonlocal usages, biases
if isinstance(layer, ErnieDecoderLayerAuto):
if not isinstance(layer.mlp, MOELayerAuto):
return
assert hasattr(
layer.mlp, "moe_statics"
), "make sure update to latest ernie-core, too use AuxFree Balance"
usages[layer.layer_idx] = layer.mlp.moe_statics.expert_usage
biases[layer.layer_idx] = layer.mlp.moe_statics.e_score_correction_bias

model.apply(get_stat)
if not usages:
return
keys, tensor_list = zip(*sorted(usages.items(), key=lambda x: x[0]))
usages_tensor = paddle.stack(tensor_list, 0)
if not hasattr(fleet, "_hcg"):
dist.all_reduce(usages_tensor)
return

hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
dp_group = hcg.get_data_parallel_group()
sd_group = hcg.get_sharding_parallel_group()
if self.use_sp and mp_group.nranks > 1:
dist.all_reduce(usages_tensor._local_value(), group=mp_group)
if dp_group.nranks > 1:
dist.all_reduce(usages_tensor._local_value(), group=dp_group)
if sd_group.nranks > 1:
dist.all_reduce(usages_tensor._local_value(), group=sd_group)
usages_mean = usages_tensor.mean(-1, keepdim=True)
update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
update_dict = dict(zip(keys, update))

def update_bias(layer):
nonlocal usages, biases
if isinstance(layer, ErnieDecoderLayerAuto):
if not isinstance(layer.mlp, MOELayerAuto):
return
with paddle.no_grad():
if layer.mlp.gate.weight.stop_gradient:
update_dict[layer.layer_idx][0, :] = 0
biases[layer.layer_idx].add_(update_dict[layer.layer_idx])
usages[layer.layer_idx].data.zero_()

model.apply(update_bias)
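The callback above implements auxiliary-loss-free load balancing. At the end of each optimizer step (the on_optimizer_end hook) it collects the per-expert usage counters from every ErnieDecoderLayerAuto whose MLP is a MOELayerAuto, all-reduces them across the data- and sharding-parallel groups (and the model-parallel group when sequence parallelism is on), nudges each expert's score-correction bias by a fixed step toward the per-layer mean usage, and zeroes the counters. A minimal single-layer sketch of that update rule with made-up numbers (illustrative only, not part of the PR):

    import paddle

    lr = 1.0e-3  # plays the role of moe_use_aux_free_update_coef
    # tokens routed to each of 4 experts since the last adjustment
    usage = paddle.to_tensor([120.0, 80.0, 100.0, 100.0])
    bias = paddle.zeros_like(usage)  # e_score_correction_bias added to the router scores

    usage_mean = usage.mean(-1, keepdim=True)        # 100.0
    update = paddle.sign(usage_mean - usage) * lr    # [-0.001, 0.001, 0.0, 0.0]
    bias = bias + update                             # under-used experts get a boost next step
    usage = paddle.zeros_like(usage)                 # counters reset for the next interval

The real callback additionally zeroes the update row for any layer whose gate weight has stop_gradient set, and applies the change in place to the MoEStatics parameters.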
@@ -182,9 +182,9 @@ class AutoPreTrainingArguments(AutoTrainingArguments):
default="ernie",
metadata={"help": "Only support for ernie pre-training for now."},
)
n_microbatches: int = field(
default=1,
metadata={"help": "Control the num of microbatches in one pp step."},
moe_use_aux_free_update_coef: float = field(
default=1.0e-3,
metadata={"help": "moe aux free update coef"},
)

@property
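The n_microbatches field is dropped here and replaced by moe_use_aux_free_update_coef, the step size consumed by MoECorrectionBiasAdjustCallback. A hypothetical way of setting it programmatically, assuming AutoPreTrainingArguments is instantiated like a standard paddleformers training-arguments dataclass (field names from the diffs above; all other required fields omitted):

    from src.trainers import AutoPreTrainingArguments

    args = AutoPreTrainingArguments(
        output_dir="./output",                # assumed standard required field
        moe_use_aux_free_update_coef=1.0e-3,  # bias-adjust step size (default 1.0e-3)
    )

With the usual dataclass-to-CLI mapping this should also be settable as --moe_use_aux_free_update_coef 1e-3 on the launch command line (an assumption; not shown in this diff).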
10 changes: 8 additions & 2 deletions examples/pre-training/models/ernie/modeling_auto.py
@@ -48,6 +48,7 @@
from models.comm_utils import subbatch
from models.moe.moe_layer_auto import (
MOELayerAuto,
MoEStatics,
)
from models.ernie.configuration_auto import ErnieMoEConfig
from models.moe.moe_utils_auto import get_mesh
@@ -482,7 +483,11 @@ def get_gate(
)
experts[i].ep_group_id = ep_group_id

return gate, experts, lm_gate, lm_experts
if config.moe_use_aux_free:
moe_statics = MoEStatics(config, layer_idx)
else:
moe_statics = None
return gate, experts, lm_gate, lm_experts, moe_statics


class RMSNorm(nn.Layer):
@@ -1218,7 +1223,7 @@ def create_moe_mlp_layer(self, layer_idx, ipp):
fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))]
else:
fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))]
gate, experts, lm_gate, lm_experts = get_gate(
gate, experts, lm_gate, lm_experts, moe_statics = get_gate(
self.config, fc, layer_idx, self.ipp
)
_sh_cfg = deepcopy(self.config)
@@ -1259,6 +1264,7 @@ def create_moe_mlp_layer(self, layer_idx, ipp):
k=self.config.moe_k,
all_to_all_dropout=self.config.moe_all_to_all_dropout,
group_experts=self.config.moe_group_experts,
moe_statics=moe_statics,
config=self.config,
ipp=self.ipp,
)
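get_gate now returns a fifth value, moe_statics, which is non-None only when config.moe_use_aux_free is set, and create_moe_mlp_layer forwards it into MOELayerAuto. A quick shape sketch for the statistics object defined in the moe_layer_auto.py diff below, using a dummy config object (illustrative only; the real code passes an ErnieMoEConfig):

    from types import SimpleNamespace
    from models.moe.moe_layer_auto import MoEStatics

    cfg = SimpleNamespace(moe_num_experts=64, multimodel_experts=False)
    stats = MoEStatics(cfg, layer_idx=0)
    print(stats.e_score_correction_bias.shape)  # [1, 64] -> one float32 bias per expert
    print(stats.expert_usage.shape)             # [1, 64] -> int64 routed-token counters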
74 changes: 73 additions & 1 deletion examples/pre-training/models/moe/moe_layer_auto.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Tuple, List, Optional
import logging
from contextlib import contextmanager
@@ -35,6 +36,56 @@
logger = logging.getLogger(__name__)


class MoEStatics(nn.Layer):
"""
Stores MoE (Mixture of Experts) statistics
and expert usage information.
"""

def __init__(self, config, layer_idx):
"""
Initialize MoE statistics tracking.

Args:
config: Model configuration containing MoE parameters
layer_idx: Index of the MoE layer in the model
"""
super().__init__()
self._cast_to_low_precision = False
self._cast_to_low_precison = False
num_experts = (
config.moe_num_experts[0]
if config.multimodel_experts
else config.moe_num_experts
)
if config.multimodel_experts:
assert (
len(set(config.moe_num_experts)) == 1
), f"assume expert group has same size, got: {config.moe_num_experts}"

with paddle.utils.unique_name.guard(f"mm_layer_{layer_idx}_"):
num_experts_groups = (
len(config.moe_num_experts) if config.multimodel_experts else 1
)
p = self.create_parameter(
shape=[num_experts_groups, num_experts],
dtype="float32",
is_bias=True,
attr=paddle.ParamAttr(
name=paddle.utils.unique_name.generate("corr_bias")
),
)
p.stop_gradient = True
self.e_score_correction_bias = p
self.e_score_correction_bias.is_distributed = True
p = paddle.zeros(
shape=[num_experts_groups, num_experts],
dtype="int64",
)
p.stop_gradient = True
self.expert_usage = p


@contextmanager
def profile(name):
"""doc"""
@@ -529,13 +580,34 @@ def gate_and_distpach(self, input, token_type_ids):
prob, max_prob = self.fused_gate_logits_process(
gate_logits, token_type_ids
)
if "corr_bias" in inspect.signature(moe_gate_dispatch).parameters:
if self.use_correction_bias:
compat_args = (self.moe_statics.e_score_correction_bias[0],)
else:
compat_args = (None,)
else:
assert (
not self.use_correction_bias
), "correction bias not supported, rebuild moe-ops"
compat_args = ()
(
dispatched_input,
combine_weights_unnorm,
scatter_index,
dispatch_mask,
_,
) = moe_gate_dispatch(input, prob, None, k, local_capacity, True)
) = moe_gate_dispatch(
input, prob, *compat_args, k, local_capacity, True
)
dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0)))
if self.use_correction_bias:
if self.gate.config.multimodel_experts:
for i in range(len(self.moe_statics.expert_usage)):
self.moe_statics.expert_usage[i] += dispatch_mask[
self.gate.experts_type_mask[i]
].detach()
else:
self.moe_statics.expert_usage[0] += dispatch_mask.detach()
dispatched_input.stop_gradient = False
combine_weights_unnorm.stop_gradient = False
dispatch_mask.stop_gradient = True
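The inspect.signature check in gate_and_distpach is a compatibility shim: the correction bias is forwarded to moe_gate_dispatch only when the installed op actually declares a corr_bias parameter; otherwise the code asserts that aux-free balancing is disabled. The same pattern in a self-contained form (dummy Python functions standing in for the real op):

    import inspect

    def dispatch_old(x, prob, k):  # older op build: no corr_bias parameter
        return x, prob, k

    def dispatch_new(x, prob, corr_bias, k):  # newer op build: accepts the bias
        adjusted = prob if corr_bias is None else prob + corr_bias
        return x, adjusted, k

    def call_dispatch(op, x, prob, k, corr_bias=None):
        # Forward corr_bias only if the op's signature advertises it.
        if "corr_bias" in inspect.signature(op).parameters:
            compat_args = (corr_bias,)
        else:
            assert corr_bias is None, "correction bias not supported, rebuild moe-ops"
            compat_args = ()
        return op(x, prob, *compat_args, k)

    print(call_dispatch(dispatch_old, 1.0, 0.5, 2))                   # (1.0, 0.5, 2)
    print(call_dispatch(dispatch_new, 1.0, 0.5, 2, corr_bias=0.25))   # (1.0, 0.75, 2)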