diff --git a/examples/pre-training/ernie/pretrain_auto.py b/examples/pre-training/ernie/pretrain_auto.py new file mode 100644 index 00000000..3e8078c6 --- /dev/null +++ b/examples/pre-training/ernie/pretrain_auto.py @@ -0,0 +1,374 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import json +import numpy as np +import random +import paddle +import paddle.distributed.fleet as fleet +from src.utils_auto import logger +from paddleformers.trainer import ( + PdArgumentParser, + get_last_checkpoint, +) +from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer +from omegaconf.listconfig import ListConfig +from omegaconf.dictconfig import DictConfig +from src.callbacks_auto import ( + GlobalRNGCallback, +) +from models.ernie import ( + ErnieForCausalLMAuto, + ErnieForCausalLMAutoPP, +) +from models.ernie.configuration_auto import ( + ErnieConfig, + ErnieMoEConfig, +) +from src.trainers import AutoPretrainingTrainer, AutoPreTrainingArguments +from src.utils_auto import ( + setup_logger_output_file, +) +from src.utils_auto.misc import global_training_logs + +from paddleformers.data.causal_dataset import ( + build_train_valid_test_datasets, + check_data_split, +) + + +from config import get_config + +try: + from paddleformers.trainer.trainer_utils import log_trainer_start +except ImportError: + + def log_trainer_start(): + """Print main process messgae""" + if "MAIN_PROCESS_STARTED" not in os.environ: + start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + logger.info( + f"The Training Main Process Started Successfully. 
time: {start_time}, pid: {os.getpid()}" + ) + os.environ["MAIN_PROCESS_STARTED"] = "1" + + +from paddle.distributed.fleet import collective_perf + +log_trainer_start() + + +def create_pretrained_dataset(args): + assert args.input_dir is not None and len(args.input_dir.split()) > 1 + + check_data_split( + args.split, + args.do_train, + args.do_eval, + args.do_predict, + ) + + train_val_test_num_samples = [ + args.per_device_train_batch_size + * args.dataset_world_size + * args.max_steps + * args.gradient_accumulation_steps, + args.per_device_eval_batch_size + * args.dataset_world_size + * args.eval_iters + * (args.max_steps // args.eval_steps + 1), + args.per_device_eval_batch_size * args.dataset_world_size * args.test_iters, + ] + + train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=args.input_dir.split(), + data_impl="mmap", + splits_string=args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=args.max_seq_length + args.multi_token_pred_depth, + seed=args.seed, + skip_warmup=True, + data_cache_path=None, + ) + + from paddleformers.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = tokens_[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def update_model_config_from_args(config: ErnieConfig, model_args: dict): + for k, v in model_args.items(): + if hasattr(config, k): + logger.info(f"update model config: {k} = {v}") + setattr(config, k, v) + else: + logger.warning(f"model config key: {k} does not exist") + return config + + +def init_parameter(model): + + for param in model.parameters(): + param.initialize() + + +def main(): + """Main function""" + config = get_config(verbose=True) + os.makedirs(config.model_args.output_dir, exist_ok=True) + parser = PdArgumentParser(AutoPreTrainingArguments) + if not hasattr(config.trainer_args, "pipeline_parallel_config"): + config.trainer_args.pipeline_parallel_config = "" + + if "enable_dp_comm_overlap" in config.trainer_args.pipeline_parallel_config: + logger.warning( + "Pipeline dp_comm_overlap and FusedLinearWithGradAdd can not be used at " + "the same time." 
+ ) + + if "enable_timer" in config.trainer_args.pipeline_parallel_config: + from paddle.distributed.fleet.meta_parallel.pipeline_parallel import ( + PipelineParallel, + ) + + PipelineParallel.timer_printer = lambda _: None + + def formatv(v): + if isinstance(v, ListConfig): + return list(v) + elif isinstance(v, DictConfig): + return dict(v) + return v + + model_args = {k: formatv(v) for k, v in dict(config.model_args).items()} + trainer_args = {k: formatv(v) for k, v in dict(config.trainer_args).items()} + (args,) = parser.parse_dict(dict(**model_args, **trainer_args)) + + if args.strategy.pipeline.enable and args.virtual_pp_degree > 1: + pipeline = args.strategy.pipeline + pipeline.vpp_degree = args.virtual_pp_degree + pipeline.vpp_seg_method = args.virtual_pipeline_seg_method + + args.eval_iters = 10 + args.test_iters = args.eval_iters * 10 + + args.use_moe = dict(**dict(config.model_args), **dict(config.trainer_args)).get( + "use_moe", False + ) + model_config = dict(getattr(config.model_args, "model_config", {})) + model_config = {k: formatv(v) for k, v in model_config.items()} + logger.info(f"model_config_from_yaml: {json.dumps(model_config, indent=4)}") + setup_logger_output_file(config.model_args.output_dir, args.local_rank) + paddle.set_device(args.device) + np.random.seed(args.seed) + random.seed(args.seed) + paddle.seed(args.seed) + + prop = paddle.device.cuda.get_device_properties() + if prop.total_memory < args.pre_alloc_memory * 1024 * 1024 * 1024: + logger.warning( + "Invalid value for `pre_alloc_memory`, so pre-allocating just failed." + ) + elif args.pre_alloc_memory > 0: + logger.warning( + f"pre-allocating a tensor whose memory capacity is {args.pre_alloc_memory} GB " + "and then release it." + ) + memory_size = int(args.pre_alloc_memory * 1024 * 1024 * 1024) + x = paddle.empty([memory_size], dtype=paddle.uint8) + del x + + try: + collective_perf( + "allgather", + round=50, + size_and_time={67108864: 0.00625, 234881024: 0.02, 637534208: 0.057}, + ) + logger.info("======monitor allgather done!=======\n") + collective_perf( + "allreduce", + round=50, + size_and_time={67108864: 0.02, 134217728: 0.038, 268435456: 0.075}, + ) + logger.info("======monitor allreduce done!=======\n") + except Exception as e: + logger.warning(f"fleet test unexcepted error! skip exception[{e}]...") + + # Detecting last checkpoint. + last_checkpoint = None + if ( + os.path.isdir(args.output_dir) + and args.do_train + and not args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(args.output_dir) + if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0: + raise ValueError( + f"Output directory ({args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Define the metrics of tasks. 
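+    # Note: `p.predictions` carries the summed eval loss (NLL) gathered by the trainer
+    # and `p.label_ids` carries the label ids; the summed loss is divided by the number
+    # of label tokens that are neither `ignored_index` nor -100 to report the average
+    # `nll_loss`, its exponential `ppl`, and the valid token count.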
+ def compute_metrics(p): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + + output = paddle.to_tensor(preds) + labels = paddle.to_tensor(p.label_ids) + output = [t.astype("float32").cuda() for t in output] + labels = [t[t != tokenizer.ignored_index] for t in labels] + labels = [t.cuda() for t in labels] + all_numel = ( + (paddle.concat(labels, 0) != tokenizer.ignored_index).astype("int64").sum() + ) + ignored = (paddle.concat(labels, 0) == -100).astype("int64").sum() + labels = all_numel - ignored + output = sum(output) + logger.info(f"output : {output.item()}, labels : {labels.item()}") + nll_loss = output / (labels + 1.0e-6) + ppl = paddle.exp(nll_loss) + + return { + "nll_loss": nll_loss.item(), + "ppl": ppl.item(), + "num_token": labels.item(), + } + + # model + dtype = "float32" + if args.fp16 and args.fp16_opt_level == "O2": + paddle.set_default_dtype("float16") + dtype = "float16" + elif args.bf16: + paddle.set_default_dtype("bfloat16") + dtype = "bfloat16" + + if args.use_moe: + global ErnieConfig, ErnieForCausalLMAuto + ErnieConfig = ErnieMoEConfig + + if args.moe_group.lower() in {"mp", "tp", "model", "dummy"}: + logger.info(f"disable moe flag when using moe-group={args.moe_group}") + args.use_moe = False + + args.multi_token_pred_depth = model_config.get("multi_token_pred_depth", 0) + + cfg = ErnieConfig.from_pretrained(args.model_name_or_path) + cfg.seqlen = args.max_seq_length + cfg.token_balance_seqlen = args.max_seq_length * args.per_device_train_batch_size + cfg.fp16_opt_level = args.fp16_opt_level + cfg.moe_group = args.moe_group + cfg.dtype = dtype + cfg.pipeline_parallel_degree = args.pipeline_parallel_degree + cfg.virtual_pp_degree = args.virtual_pp_degree + if args.tensor_parallel_degree > 1: + cfg.sequence_parallel = args.sequence_parallel + cfg.tensor_parallel_degree = max( + fleet.get_hybrid_communicate_group().get_model_parallel_world_size(), 1 + ) + cfg.tensor_parallel_rank = max( + fleet.get_hybrid_communicate_group().get_model_parallel_rank(), 0 + ) + else: + cfg.sequence_parallel = False + cfg.tensor_parallel_degree = 1 + cfg.tensor_parallel_rank = 0 + + cfg.micro_batch_size = args.per_device_train_batch_size + tokenizer = ErnieBotTokenizer.from_pretrained(args.tokenizer_name) + tokenizer.ignored_index = cfg.ignored_index + logger.info( + f"using tokenizer={type(tokenizer)}, bos:{tokenizer.bos_token_id} " + f"eos:{tokenizer.eos_token_id} pad:{tokenizer.pad_token_id} " + ) + + cfg = update_model_config_from_args(cfg, model_config) + + if args.model_type == "ernie": + model_class = ErnieForCausalLMAuto + elif args.model_type == "ernie_pp": + model_class = ErnieForCausalLMAutoPP + else: + raise ValueError(f"not support model_type: {args.model_type}") + + with paddle.LazyGuard(): + model = model_class(cfg) + + cfg = model.config + logger.info(f"using model type:{type(model)}") + paddle.set_default_dtype("float32") + + logger.info(f"using model={type(model)}, cfg={cfg}") + + # data + logger.info("loading data...") + train_dataset, eval_dataset, test_dataset, data_collator = ( + create_pretrained_dataset(args) + ) + + callbacks = [GlobalRNGCallback()] + + init_parameter(model) + model.apply(model.init_weights) + trainer = AutoPretrainingTrainer( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + callbacks=callbacks, + ) + global_training_logs.accumulate = args.gradient_accumulation_steps + checkpoint = 
None + if args.resume_from_checkpoint is not None: + checkpoint = args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + trainer.save_model(args.output_dir) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluate and tests model + if args.do_eval: + eval_metrics = trainer.evaluate() + trainer.log_metrics("eval", eval_metrics) + + +if __name__ == "__main__": + main() diff --git a/examples/pre-training/ernie/src/callbacks_auto/__init__.py b/examples/pre-training/ernie/src/callbacks_auto/__init__.py new file mode 100644 index 00000000..26eac2da --- /dev/null +++ b/examples/pre-training/ernie/src/callbacks_auto/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .logging_callback import LoggingCallback +from .stopper_callback import StopperCallback +from .moe_logging_callback import GlobalRNGCallback +from .tensorboard_callback import TensorBoardCallback + +__all__ = [ + "TensorBoardCallback", + "LoggingCallback", + "GlobalRNGCallback", + "StopperCallback", +] diff --git a/examples/pre-training/ernie/src/callbacks_auto/logging_callback.py b/examples/pre-training/ernie/src/callbacks_auto/logging_callback.py new file mode 100644 index 00000000..c2435a4b --- /dev/null +++ b/examples/pre-training/ernie/src/callbacks_auto/logging_callback.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
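+
+# LoggingCallback renders the trainer's log dict (loss, lr, and data_id/src_id/data_type
+# when present in the current batch) as a single human-readable INFO line and forwards
+# it to an optional `metrics_dumper` if one is supplied.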
+ +import logging + +from paddleformers.trainer.trainer_callback import TrainerCallback + +logger = logging.getLogger(__name__) + + +class LoggingCallback(TrainerCallback): + def __init__( + self, + ) -> None: + super().__init__() + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if "inputs" in kwargs: + data_id = kwargs["inputs"].get("data_id", None) + src_id = kwargs["inputs"].get("src_id", None) + data_type = kwargs["inputs"].get("data_type", None) + + if data_id is not None: + logs = dict( + logs, data_id="-".join(map(str, (data_id.numpy().tolist()))) + ) + if src_id is not None: + logs = dict(logs, src_id="-".join(map(str, (src_id.numpy().tolist())))) + if data_type is not None: + logs.update(data_type="-".join(map(str, (data_type.numpy().tolist())))) + + if type(logs) is dict: + logger.info( + ", ".join( + ( + ( + f"{k}: {v}" + if k == "loss" or "cur_dp" in k + else f"{k}: {v:e}" if v < 1e-3 else f"{k}: {v:f}" + ) + if isinstance(v, float) + else f"{k}: {v}" + ) + for k, v in logs.items() + ) + ) + metrics_dumper = kwargs.get("metrics_dumper", None) + if metrics_dumper is not None: + metrics_dumper.append(logs) + else: + logger.info(logs) diff --git a/examples/pre-training/ernie/src/callbacks_auto/moe_logging_callback.py b/examples/pre-training/ernie/src/callbacks_auto/moe_logging_callback.py new file mode 100644 index 00000000..e5bd2e86 --- /dev/null +++ b/examples/pre-training/ernie/src/callbacks_auto/moe_logging_callback.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from paddleformers.trainer.trainer_callback import TrainerCallback +from models.ernie.modeling_moe import ErnieMoEForCausalLM + +__all__ = ["GlobalRNGCallback"] + + +class GlobalRNGCallback(TrainerCallback): + def on_step_end(self, args, state, control, model, **kwargs): + isinstance(model, ErnieMoEForCausalLM), type(model) + random.Random(state.global_step) diff --git a/examples/pre-training/ernie/src/callbacks_auto/stopper_callback.py b/examples/pre-training/ernie/src/callbacks_auto/stopper_callback.py new file mode 100644 index 00000000..2b776309 --- /dev/null +++ b/examples/pre-training/ernie/src/callbacks_auto/stopper_callback.py @@ -0,0 +1,29 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
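+
+# StopperCallback polls for a sentinel file (`/root/stop`) at every substep and requests
+# a graceful training stop once it appears, so a run can be interrupted from outside.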
+ + +import os +import logging +from paddleformers.trainer.trainer_callback import TrainerCallback + +logger = logging.getLogger(__name__) + + +class StopperCallback(TrainerCallback): + + def on_substep_end(self, args, state, control, **kwargs): + if os.path.exists("/root/stop"): + control.should_training_stop = True diff --git a/examples/pre-training/ernie/src/callbacks_auto/tensorboard_callback.py b/examples/pre-training/ernie/src/callbacks_auto/tensorboard_callback.py new file mode 100644 index 00000000..f420d02c --- /dev/null +++ b/examples/pre-training/ernie/src/callbacks_auto/tensorboard_callback.py @@ -0,0 +1,202 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import json + +from paddleformers.peft.lora import LoRAModel +from paddleformers.trainer.trainer_callback import TrainerCallback +from paddleformers.transformers import PretrainedModel +from paddleformers.utils.log import logger + +try: + from paddleformers.trainer.trainer import clear_async_save_task_queue +except Exception: + clear_async_save_task_queue = None + + +def is_tensorboard_available(): + return ( + importlib.util.find_spec("tensorboard") is not None + or importlib.util.find_spec("tensorboardX") is not None + ) + + +def rewrite_logs(d): + new_d = {} + eval_prefix = "eval_" + eval_prefix_len = len(eval_prefix) + test_prefix = "test_" + test_prefix_len = len(test_prefix) + for k, v in d.items(): + if k.startswith(eval_prefix): + new_d["eval/" + k[eval_prefix_len:]] = v + elif k.startswith(test_prefix): + new_d["test/" + k[test_prefix_len:]] = v + else: + new_d["train/" + k] = v + return new_d + + +class TensorBoardCallback(TrainerCallback): + def __init__( + self, + args, + model, + tb_writer=None, + log_flops_per_step=False, + log_tokens_per_step=False, + ): + has_tensorboard = is_tensorboard_available() + if not has_tensorboard: + raise RuntimeError( + "TensorBoardCallback requires tensorboard to be installed. Either update or install tensorboardX." 
+ ) + if has_tensorboard: + try: + from torch.utils.tensorboard import SummaryWriter + + self._SummaryWriter = SummaryWriter + except ImportError: + try: + from tensorboardX import SummaryWriter + + self._SummaryWriter = SummaryWriter + except ImportError: + self._SummaryWriter = None + else: + self._SummaryWriter = None + self.tb_writer = tb_writer + + def get_numel_item(p): + item = p.numel().item() + return item if item else 0 + + self.model_numel = sum( + get_numel_item(p) + for n, p in model.named_parameters() + if not p.stop_gradient and "embeddings" not in n and "embed_tokens" not in n + ) + self.log_flops_per_step = log_flops_per_step + self.log_tokens_per_step = log_tokens_per_step + + def _init_summary_writer(self, args, log_dir=None): + log_dir = log_dir or args.logging_dir + if self._SummaryWriter is not None: + self.tb_writer = self._SummaryWriter(log_dir=log_dir) + + def on_train_begin(self, args, state, control, **kwargs): + if not state.is_world_process_zero: + return + + log_dir = None + + if self.tb_writer is None: + self._init_summary_writer(args, log_dir) + + if self.tb_writer is not None: + self.tb_writer.add_text("args", args.to_json_string()) + if "model" in kwargs: + model = kwargs["model"] + + if ( + isinstance(model, PretrainedModel) + and model.constructed_from_pretrained_config() + ) or isinstance(model, LoRAModel): + model.config.architectures = [model.__class__.__name__] + self.tb_writer.add_text("model_config", str(model.config)) + + elif hasattr(model, "init_config") and model.init_config is not None: + model_config_json = json.dumps( + model.get_model_config(), ensure_ascii=False, indent=2 + ) + self.tb_writer.add_text("model_config", model_config_json) + + def on_log(self, args, state, control, logs=None, **kwargs): + if not state.is_world_process_zero: + return + + timers = kwargs.get("timers") + paddle_pipeline_timers = kwargs.get("paddle_pipeline_timers") + + if self.tb_writer is None: + self._init_summary_writer(args) + + if self.tb_writer is not None: + logs = rewrite_logs(logs) + + total_tokens_per_step = ( + args.train_batch_size + * args.gradient_accumulation_steps + * args.reeao_dataset_world_size + * args.max_seq_length + ) + + if self.log_flops_per_step: + logger.warning("The FLOPs might be not accurate") + flops_per_step = self.model_numel * total_tokens_per_step * 6 + else: + flops_per_step = None + + if self.log_tokens_per_step: + tokens_per_step = total_tokens_per_step + else: + tokens_per_step = None + inputs = kwargs.get("inputs") + data_type = inputs and inputs.get("data_type") + if data_type is not None: + data_type = data_type.tolist()[-1] + logs.update(data_type=data_type) + + for k, v in logs.items(): + if isinstance(v, (int, float)): + self.tb_writer.add_scalar(k, v, state.global_step) + + if tokens_per_step is not None and k in ["train/loss"]: + self.tb_writer.add_scalar( + k + "_xaxis_tokens", v, state.global_step * tokens_per_step + ) + + if flops_per_step is not None and k in ["train/loss"]: + self.tb_writer.add_scalar( + k + "_xaxis_flops", v, state.global_step * flops_per_step + ) + + else: + logger.warning( + "Trainer is attempting to log a value of " + f'"{v}" of type {type(v)} for key "{k}" as a scalar. ' + "This invocation of Tensorboard's writer.add_scalar() " + "is incorrect so we dropped this attribute." 
+ ) + if timers is not None: + timers.write( + timers.timers.keys(), self.tb_writer, state.global_step, reset=False + ) + + if paddle_pipeline_timers: + for name, timer in paddle_pipeline_timers.timers.items(): + elapsed_time = timer.elapsed(reset=False) + self.tb_writer.add_scalar( + f"timers/{name}", elapsed_time, state.global_step + ) + + self.tb_writer.flush() + + def on_train_end(self, args, state, control, **kwargs): + if clear_async_save_task_queue: + clear_async_save_task_queue() + if self.tb_writer: + self.tb_writer.close() + self.tb_writer = None diff --git a/examples/pre-training/ernie/src/clip/__init__.py b/examples/pre-training/ernie/src/clip/__init__.py index 6484ef44..f4c56fec 100644 --- a/examples/pre-training/ernie/src/clip/__init__.py +++ b/examples/pre-training/ernie/src/clip/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from .moe_clip import ClipGradForMOEByGlobalNorm +from .moe_clip_auto import ClipGradForMOEByGlobalNormAuto -__all__ = ['ClipGradForMOEByGlobalNorm'] +__all__ = ["ClipGradForMOEByGlobalNorm", "ClipGradForMOEByGlobalNormAuto"] diff --git a/examples/pre-training/ernie/src/clip/moe_clip_auto.py b/examples/pre-training/ernie/src/clip/moe_clip_auto.py new file mode 100644 index 00000000..c82130ee --- /dev/null +++ b/examples/pre-training/ernie/src/clip/moe_clip_auto.py @@ -0,0 +1,175 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math + +import paddle +import paddle.distributed as dist +from paddle.autograd import no_grad +from paddle.framework import core +from paddle.nn import clip +from paddle.nn.clip import ClipGradBase, _squared_l2_norm + +logger = logging.getLogger(__name__) + + +class ClipGradForMOEByGlobalNormAuto(ClipGradBase): + def __init__( + self, + clip_norm, + is_expert_param_func=None, + moe_group=None, + group_name="default_moe_group", + local_clip=False, + ): + super().__init__() + self.clip_norm = float(clip_norm) + self.group_name = group_name + self.moe_group = moe_group + if moe_group is not None and moe_group.nranks > 1: + assert ( + is_expert_param_func is not None + ), "When moe group size > 1, a function for selecting expert params must be specified." 
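+        # `is_expert_param_func(param)` must return True for expert (MoE) parameters;
+        # their squared grad-norm sum is all-reduced across `moe_group` in `_dygraph_clip`
+        # before being combined with the norm of the regular parameters.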
+ self.is_expert_param_func = is_expert_param_func + self.stat = {} + self.local_clip = local_clip + + def __str__(self): + return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm) + + @staticmethod + def get_l2_norm_pow(params_grads, sum_dtype=None): + sum_square_list = [] + sum_square_list_fp16 = [] + sum_square_list_fp32 = [] + for p, g in params_grads: + if g is None: + continue + if getattr(p, "need_clip", True) is False: + continue + merge_grad = g + if g.type == core.VarDesc.VarType.SELECTED_ROWS: + merge_grad = clip.merge_selected_rows(g) + merge_grad = clip.get_tensor_from_selected_rows(merge_grad) + sum_square = _squared_l2_norm(merge_grad) + if sum_square.dtype == core.VarDesc.VarType.FP16: + sum_square_list_fp16.append(sum_square) + elif sum_square.dtype == core.VarDesc.VarType.FP32: + sum_square_list_fp32.append(sum_square) + else: + sum_square_list.append(sum_square.cast("float64")) + + if ( + len(sum_square_list) + len(sum_square_list_fp16) + len(sum_square_list_fp32) + == 0 + ): + return None, None + assert sum_dtype in [ + "float64", + "float32", + None, + ], "sum's type must be float64/ float32 / None" + if sum_dtype != "float64": + sum_dtype = "float64" if len(sum_square_list) > 0 else "float32" + + global_norm_var = [] + if len(sum_square_list_fp16) > 0: + global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16) + global_norm_var.append(global_norm_var_fp16.astype(sum_dtype)) + if len(sum_square_list_fp32) > 0: + global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32) + if sum_dtype == "float32": + global_norm_var.append(global_norm_var_fp32) + else: + global_norm_var.append(global_norm_var_fp32.astype(sum_dtype)) + if len(sum_square_list) > 0: + global_norm_var_fp64 = paddle.add_n(sum_square_list) + global_norm_var.append(global_norm_var_fp64) + global_norm_var = paddle.add_n(global_norm_var) + return global_norm_var, sum_dtype + + @no_grad() + def _dygraph_clip(self, params_grads): + normal_params_grads = [] + moe_params_grads = [] + + if self.moe_group is not None and self.moe_group.nranks > 1: + for p, g in params_grads: + if self.is_expert_param_func(p): + moe_params_grads.append((p, g)) + else: + normal_params_grads.append((p, g)) + else: + normal_params_grads = params_grads + + global_norm_var_normal, sum_dtype = self.get_l2_norm_pow(normal_params_grads) + global_norm_var_moe = None + if len(moe_params_grads) > 0: + global_norm_var_moe, _ = self.get_l2_norm_pow(moe_params_grads, sum_dtype) + if global_norm_var_moe is not None: + dist.all_reduce( + global_norm_var_moe, + op=dist.ReduceOp.SUM, + group=self.moe_group, + ) + + if global_norm_var_normal is None and global_norm_var_moe is None: + return params_grads + elif global_norm_var_normal is None: + global_norm_var = global_norm_var_moe + elif global_norm_var_moe is None: + global_norm_var = global_norm_var_normal + else: + if global_norm_var_normal.dtype != global_norm_var_moe.dtype: + global_norm_var_normal = global_norm_var_normal.astype( + global_norm_var_moe.dtype + ) + if self.local_clip: + global_norm_var = global_norm_var_normal + else: + global_norm_var = global_norm_var_normal + global_norm_var_moe + self.stat["local_grad_norm"] = math.sqrt( + global_norm_var_normal.astype("float32").item() + ) + self.stat["moe_grad_norm"] = math.sqrt( + global_norm_var_moe.astype("float32").item() + ) + self.stat["global_grad_norm"] = math.sqrt( + global_norm_var.astype("float32").item() + ) + + params_and_grads = [] + global_norm_var = paddle.sqrt(global_norm_var) + max_global_norm = paddle.full( + 
shape=[1], dtype=global_norm_var.dtype, fill_value=self.clip_norm + ) + clip_var = paddle.divide( + x=max_global_norm, + y=paddle.maximum(x=global_norm_var, y=max_global_norm), + ) + for p, g in params_grads: + if g is None: + continue + if getattr(p, "need_clip", True) is False: + params_and_grads.append((p, g)) + continue + clip_input = ( + clip_var.astype("float16") + if g.dtype == core.VarDesc.VarType.FP16 + else clip_var + ) + new_grad = paddle.multiply(x=g, y=clip_input.astype(g.dtype)) + params_and_grads.append((p, new_grad)) + return params_and_grads diff --git a/examples/pre-training/ernie/src/datasets/dist_data_loader.py b/examples/pre-training/ernie/src/datasets/dist_data_loader.py new file mode 100644 index 00000000..54f030f9 --- /dev/null +++ b/examples/pre-training/ernie/src/datasets/dist_data_loader.py @@ -0,0 +1,540 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import deque +from collections import OrderedDict +from itertools import groupby +from functools import reduce +from dataclasses import dataclass + +import numpy as np +import paddle +from paddle.distributed import fleet +import paddle.distributed as dist +from paddle.utils.layers_utils import flatten, map_structure, pack_sequence_as + +from paddleformers.utils.batch_sampler import DistributedBatchSampler +from paddleformers.trainer.plugins.timer import get_timers +from paddleformers.utils.tools import get_env_device + +from src.utils.misc import global_training_logs + + +input_ids_for_mtp = deque() + +log = logging.getLogger(__name__) + +_MAX_DATA_DIM = 64 + + +class DummyDataset(paddle.io.Dataset): + def __len__(self): + return 0 + + +class DistDataLoader(paddle.io.DataLoader): + + def __init__( + self, + dataset, + feed_list=None, + places=None, + return_list=True, + batch_sampler=None, + batch_size=1, + shuffle=False, + drop_last=False, + collate_fn=None, + num_workers=0, + use_buffer_reader=True, + prefetch_factor=2, + use_shared_memory=True, + timeout=0, + worker_init_fn=None, + persistent_workers=False, + need_data=True, + pp_broadcast=True, + need_magic_trans=False, + ): + if dataset is None: + dataset = DummyDataset() + batch_sampler = DistributedBatchSampler(dataset, 1) + log.info("rank has no data, use Dummpy dataset") + super().__init__( + dataset=dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=num_workers, + ) + self.need_magic_trans = need_magic_trans + self._hcg = fleet.get_hybrid_communicate_group() + + if self._hcg.get_pipe_parallel_world_size() > 1 and pp_broadcast: + self._pp_data_group = self._init_dataloader_comm_group() + else: + self._pp_data_group = None + + self.mp_rank = self._hcg.get_model_parallel_rank() + self.mp_group = self._hcg.get_model_parallel_group() + self.mp_src_rank = self._hcg.get_model_parallel_group_src_rank() + + self.pp_rank = self._hcg.get_stage_id() + self.dp_rank = self._hcg.get_data_parallel_rank() + sharding_rank = 
self._hcg.get_sharding_parallel_rank() + self._need_data = need_data + if self._need_data: + self._dataloder = paddle.io.DataLoader( + dataset, + feed_list, + places, + return_list, + batch_sampler, + batch_size, + shuffle, + drop_last, + collate_fn, + num_workers, + use_buffer_reader, + prefetch_factor, + use_shared_memory, + timeout, + worker_init_fn, + persistent_workers, + ) + + self._lazy_dataloader_iter = None + else: + log.info( + f"mp{self.mp_rank}_pp{self.pp_rank}_sharding{sharding_rank}_dp{self.dp_rank} no data needed, " + "skip init dataloader." + ) + + @property + def _dataloder_iter(self): + if self._lazy_dataloader_iter is None: + self._lazy_dataloader_iter = iter(self._dataloder) + return self._lazy_dataloader_iter + + def __len__(self): + if self._need_data: + return super().__len__() + else: + raise ValueError( + "raise error for `paddlenlp.trainer.trainer_utils.has_length`" + ) + + def _init_dataloader_comm_group(self): + topo = self._hcg._topo + parallel_comm_group = None + parallel_groups = topo.get_comm_list("pipe") + + for group in parallel_groups: + if self.need_magic_trans: + assert ( + len(group) > 2 + ), f"magic_trans need ranks in group greater than 2, but get {len(group)}" + ranks = [group[0], group[-2], group[-1]] + else: + ranks = [group[0], group[-1]] + comm_group = paddle.distributed.new_group(ranks=ranks) + if paddle.distributed.get_rank() in ranks: + parallel_comm_group = comm_group + return parallel_comm_group + + def __iter__(self): + return self + + def __next__(self): + get_timers() and get_timers()("read-raw-data").start() + if self._need_data: + data = next(self._dataloder_iter) + if "data_not_valid" in data: + global_training_logs.update( + data_not_valid=data["data_not_valid"].astype("float32").mean() + ) + ( + input_ids, + labels, + data_type, + images, + token_type_ids, + image_type_ids, + audio_input_ids, + audio_labels, + grid_thw, + inbatch_pack_offset, + position_ids, + log_prob, + ) = ( + data["input_ids"], + data["labels"], + data.get("data_type", None), + data.get("images", None), + data.get("token_type_ids", None), + data.get("image_type_ids", None), + data.get("audio_input_ids", None), + data.get("audio_labels", None), + data.get("grid_thw", None), + data.get("inbatch_pack_offset", None), + data.get("position_ids", None), + data.get("log_prob", None), + ) + assert {input_ids.dtype, labels.dtype} == {paddle.int64}, ( + f"Distloader requires dtype == `int64`, " + f"got:{[input_ids.dtype, labels.dtype]}" + ) + else: + ( + input_ids, + labels, + data_type, + images, + token_type_ids, + image_type_ids, + audio_input_ids, + audio_labels, + grid_thw, + inbatch_pack_offset, + position_ids, + log_prob, + ) = ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + get_timers() and get_timers()("read-raw-data").stop() + + pp_broadcast = (self._pp_data_group is None) or self.pp_rank == 0 + if self.mp_group is not None and self.mp_group.nranks > 1 and pp_broadcast: + ( + input_ids, + labels, + data_type, + images, + token_type_ids, + image_type_ids, + audio_input_ids, + audio_labels, + grid_thw, + inbatch_pack_offset, + position_ids, + log_prob, + ) = broadcast_data_obj( + [ + input_ids, + labels, + data_type, + images, + token_type_ids, + image_type_ids, + audio_input_ids, + audio_labels, + grid_thw, + inbatch_pack_offset, + position_ids, + log_prob, + ], + self.mp_src_rank, + self.mp_group, + ) + + if self._pp_data_group is not None and self._pp_data_group.nranks > 1: + ( + input_ids, + labels, + 
data_type, + images, + token_type_ids, + image_type_ids, + audio_input_ids, + audio_labels, + grid_thw, + inbatch_pack_offset, + position_ids, + log_prob, + ) = broadcast_data_obj( + [ + input_ids, + labels, + data_type, + images, + token_type_ids, + image_type_ids, + audio_input_ids, + audio_labels, + grid_thw, + inbatch_pack_offset, + position_ids, + log_prob, + ], + self._pp_data_group.ranks[0], + self._pp_data_group, + ) + + if self.need_magic_trans: + if input_ids is not None: + global input_ids_for_mtp + input_ids_for_mtp.append(input_ids) + + to_return = OrderedDict( + [ + ("input_ids", input_ids), + ("labels", labels), + ("data_type", data_type), + ("images", images), + ("token_type_ids", token_type_ids), + ("image_type_ids", image_type_ids), + ("audio_input_ids", audio_input_ids), + ("audio_labels", audio_labels), + ("grid_thw", grid_thw), + ("inbatch_pack_offset", inbatch_pack_offset), + ("position_ids", position_ids), + ] + ) + optional_keys = [ + "data_type", + "images", + "token_type_ids", + "image_type_ids", + "audio_input_ids", + "audio_labels", + "grid_thw", + "inbatch_pack_offset", + "position_ids", + "log_prob", + ] + none_keys = [ + k for k, v in to_return.items() if v is None and k in optional_keys + ] + for k in none_keys: + to_return.pop(k) + return to_return + + +def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0): + """ + Broadcast data from src_rank to all ranks in comm_group. + """ + size_cpu = [] + if comm_rank == 0: + for data in data_list: + size_cpu.append(len(data.shape)) + size_cpu += data.shape + size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu)) + size_cuda = paddle.to_tensor(size_cpu) + paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait() + + size_cpu = size_cuda.tolist() + i = 0 + numel = 0 + sizes = [] + while size_cpu[i] > 0: + rank = size_cpu[i] + this_size = size_cpu[i + 1 : i + 1 + rank] + numel += int(np.prod(this_size)) + sizes.append(this_size) + i += 1 + rank + + if comm_rank == 0: + assert data.dtype == datatype, ( + f"input has data type {data.dtype} which " f"is different than {datatype}" + ) + data_b = paddle.concat( + [d.to(get_env_device()).reshape([-1]) for d in data_list], 0 + ) + assert numel == sum([d.numel().item() for d in data_list]), ( + numel, + [d.numel().item() for d in data_list], + ) + else: + data_b = paddle.empty([numel], dtype=datatype).to(get_env_device()) + + paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait() + + ret = [] + offset = 0 + for size in sizes: + numel = int(np.prod(size)) + ret.append(data_b[offset : offset + numel].reshape(size)) + offset += numel + + return ret + + +@dataclass +class _DtypeSndShape: + """_summary_ + + Returns + ------- + _type_: _description_ + """ + + dtype: paddle.dtype + shape: list + + def size(self): + """_summary_ + + Returns + ------- + _type_: _description_ + """ + return reduce(lambda x, y: x * y, self.shape) + + +def split_group(grouped, split_size): + """_summary_ + + Args: + grouped (_type_): _description_ + split_size (_type_): _description_ + + Yields + ------ + _type_: _description_ + """ + ret = [] + while grouped: + if sum([r[1].size() for r in ret]) > split_size: + yield ret + ret = [] + ret.append(grouped.pop()) + if ret: + yield ret + + +def broadcast_data_obj(data, src_rank, group): + this_rank = dist.get_rank() + if this_rank == src_rank: + template = [ + map_structure( + lambda x: ( + _DtypeSndShape(dtype=x.dtype, shape=x.shape) + if x is not None + else _DtypeSndShape(dtype="", 
shape=[0]) + ), + data, + ) + ] + else: + template = [None] + dist.broadcast_object_list(template, src_rank, group) + template = template[0] + + temp_flat = flatten(template) + data_flat = flatten(data) + + def keyfn(i): + return str(i[1].dtype) + + ret_flat = [-1 for _ in range(len(temp_flat))] + for dtype, grouped in groupby(sorted(enumerate(temp_flat), key=keyfn), keyfn): + grouped = list(grouped) + for grouped_chunk in split_group(grouped, 2**18): + idxs = [g[0] for g in grouped_chunk] + if not dtype: + for id in idxs: + ret_flat[id] = None + continue + + data_buf_shapes = [ + reduce(lambda x, y: x * y, g[1].shape) for g in grouped_chunk + ] + if this_rank == src_rank: + data_buf = paddle.concat([data_flat[i].reshape([-1]) for i in idxs], 0) + else: + data_buf = paddle.empty( + [sum(data_buf_shapes)], dtype=grouped_chunk[0][1].dtype + ) + dist.broadcast(data_buf, src_rank, group) + + if this_rank != src_rank: + if len(data_buf_shapes) == 1: + data_buf = [data_buf] + else: + data_buf = data_buf.split(data_buf_shapes, axis=0) + for g, data_chunk in zip(grouped_chunk, data_buf): + ret_flat[g[0]] = data_chunk.reshape(g[1].shape) + + if this_rank != src_rank: + assert not [r for r in ret_flat if r is -1], ret_flat + data = pack_sequence_as(template, ret_flat) + return data + + +class DistDataLoaderAuto(DistDataLoader): + + def _init_dataloader_comm_group(self): + return self._hcg.get_pipe_parallel_group() + + def __next__(self): + data_dict = super().__next__() + + input_list = [] + if "token_type_ids" in data_dict.keys(): + + ( + input_ids, + labels, + data_type, + images, + token_type_ids, + image_type_ids, + grid_thw, + ) = ( + data_dict["input_ids"], + data_dict["labels"], + data_dict["data_type"], + data_dict.get("images", None), + data_dict["token_type_ids"], + data_dict.get("image_type_ids", None), + data_dict.get("grid_thw", None), + ) + + data_world_size = max(self._hcg.get_data_parallel_rank(), 1) * max( + self._hcg.get_sharding_parallel_rank(), 1 + ) + if images is None: + images = paddle.zeros([1, 64, 64], dtype="uint8") + has_images = paddle.full([data_world_size, 1], False, dtype="bool") + else: + raise NotImplementedError + has_images = paddle.full([data_world_size, 1], True, dtype="bool") + if image_type_ids is None: + image_type_ids = paddle.zeros_like(token_type_ids) + input_list = [ + input_ids, + labels, + data_type, + images, + token_type_ids, + image_type_ids, + has_images, + grid_thw, + ] + else: + for key, data in data_dict.items(): + input_list.append(data) + return OrderedDict([("input_ids", input_list), ("labels", [])]) diff --git a/examples/pre-training/ernie/src/lr_schedulers/__init__.py b/examples/pre-training/ernie/src/lr_schedulers/__init__.py index 77159c8e..71081f4b 100644 --- a/examples/pre-training/ernie/src/lr_schedulers/__init__.py +++ b/examples/pre-training/ernie/src/lr_schedulers/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from .cosine_lr import get_cosine_schedule_with_warmup from .wsd_lr import get_wsd_schedule_with_warmup -__all__ = ['get_wsd_schedule_with_warmup'] +__all__ = ["get_wsd_schedule_with_warmup", "get_cosine_schedule_with_warmup"] diff --git a/examples/pre-training/ernie/src/lr_schedulers/cosine_lr.py b/examples/pre-training/ernie/src/lr_schedulers/cosine_lr.py new file mode 100644 index 00000000..6059c60a --- /dev/null +++ b/examples/pre-training/ernie/src/lr_schedulers/cosine_lr.py @@ -0,0 +1,62 @@ +# !/usr/bin/env python3 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Custom lr schedule +""" + +import math +from paddle.optimizer.lr import LambdaDecay + + +def get_cosine_schedule_with_warmup( + learning_rate: float, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, + min_lr: float = 0.0, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + Args: + learning_rate (float) + The initial learning rate. It is a python float number. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + Return: + `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule. 
+ """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + ratio = max( + 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + ) + return ratio * (1 - min_lr / learning_rate) + min_lr / learning_rate + + return LambdaDecay(learning_rate, lr_lambda, last_epoch) diff --git a/examples/pre-training/ernie/src/trainers/__init__.py b/examples/pre-training/ernie/src/trainers/__init__.py index 254a42c3..477eeef4 100644 --- a/examples/pre-training/ernie/src/trainers/__init__.py +++ b/examples/pre-training/ernie/src/trainers/__init__.py @@ -17,9 +17,12 @@ PretrainingTrainer, WeightedDistributedSampler, ) +from .pretraining_trainer_auto import AutoPretrainingTrainer, AutoPreTrainingArguments __all__ = [ - 'PretrainingTrainer', - 'PreTrainingArguments', - 'WeightedDistributedSampler', + "PretrainingTrainer", + "PreTrainingArguments", + "WeightedDistributedSampler", + "AutoPretrainingTrainer", + "AutoPreTrainingArguments", ] diff --git a/examples/pre-training/ernie/src/trainers/pretraining_trainer.py b/examples/pre-training/ernie/src/trainers/pretraining_trainer.py index 01f20a6e..65477b03 100644 --- a/examples/pre-training/ernie/src/trainers/pretraining_trainer.py +++ b/examples/pre-training/ernie/src/trainers/pretraining_trainer.py @@ -93,7 +93,7 @@ FP8QuantWeightCallback, ) from src.callbacks.moe_logging_callback import MoeLoggingCallback -from src.clip import ClipGradForMOEByGlobalNorm +from src.clip import ClipGradForMOEByGlobalNormAuto from src.lr_schedulers import get_wsd_schedule_with_warmup from src.trainers.data_parallel import sync_dp_moe_params_across_sharding from src.utils.misc import global_training_logs @@ -493,7 +493,7 @@ def load_data_seq_from_cache(self): def gen_data_seq_weighted(self, num_examples, data_type=None): assert ( self.load_data_seq is False - ), "需要保证所有epoch的data_seq都从文件加载,否则下次删data_seq无法控住随机性" + ), "Ensure that the data_seq for all epochs is loaded from the file; otherwise, the randomness cannot be controlled when deleting data_seq next time." logger.info( f"generating data sequence... #non_consecutive_data_chunks={num_examples}," f" num_consecutive={self.num_consecutive}" @@ -1540,7 +1540,7 @@ def apply_decay_param_fun(x): def expert_fn(p): return getattr(p, "no_sync", False) - grad_clip = ClipGradForMOEByGlobalNorm( + grad_clip = ClipGradForMOEByGlobalNormAuto( self.args.max_grad_norm, is_expert_param_func=expert_fn, moe_group=_get_global_group(), diff --git a/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py b/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py new file mode 100644 index 00000000..68624cf6 --- /dev/null +++ b/examples/pre-training/ernie/src/trainers/pretraining_trainer_auto.py @@ -0,0 +1,659 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""AutoPretrainingTrainer""" + +__all__ = [ + "AutoPretrainingTrainer", +] + + +import sys +import re +import os +import json +import contextlib +from typing import Optional +from dataclasses import dataclass, field +import time +import math +import logging +from functools import partial + + +import paddle +import paddle.nn as nn +from paddle.io import DataLoader +import paddle.amp.auto_cast as autocast +from paddle.distributed.communication.group import _get_global_group + +from paddleformers.trainer import ( + speed_metrics, +) + +from paddleformers.trainer.auto_trainer import AutoTrainer + + +from paddleformers.utils.batch_sampler import ( + DistributedBatchSampler as PaddleNLPDistributedBatchSampler, +) + + +from paddleformers.trainer.utils import add_start_docstrings +from paddleformers.trainer.trainer_callback import PrinterCallback +from paddle.distributed import fleet +import paddle.distributed as dist + + +from src.lr_schedulers import get_cosine_schedule_with_warmup +from src.utils_auto.training_utils import ( + reset_per_device_batch_size, +) +from src.callbacks_auto import ( + TensorBoardCallback, + LoggingCallback, + StopperCallback, +) +from src.datasets.dist_data_loader import ( + DistDataLoaderAuto, +) +from src.clip import ClipGradForMOEByGlobalNormAuto + + +logger = logging.getLogger(__name__) + +try: + from paddleformers.trainer import AutoTrainingArguments +except ImportError: + from paddleformers.trainer import TrainingArguments as AutoTrainingArguments + + logger.warning("paddlenlp.trainer.AutoTrainingArguments CANNOT import!") + logger.warning("Use TrainingArguments as an alternative but will lose some args!") + + +DATATYPE_2_ID = {"mm": 0, "lm": 1, "audio": 2} + + +@dataclass +@add_start_docstrings(AutoTrainingArguments.__doc__) +class AutoPreTrainingArguments(AutoTrainingArguments): + + multimodal: bool = field( + default=False, metadata={"help": "whether training with multimodal"} + ) + model_name_or_path: str = field( + default=None, + metadata={ + "help": "Path to pretrained model or model identifier from " + "https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + + prefetch_factor: int = field( + default=2, + metadata={"help": "global random seed factor."}, + ) + eval_iters: int = field( + default=-1, + metadata={"help": "eval iteration for every evaluation."}, + ) + + min_lr: float = field( + default=0.0, + metadata={"help": "minus learning rate"}, + ) + + input_dir: str = field(default=None, metadata={"help": "data path"}) + split: str = field( + default="949,50,1", metadata={"help": "Train/valid/test data split ratio"} + ) + + max_seq_length: int = field( + default=512, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." 
+ }, + ) + global_batch_size: int = field( + default=-1, + metadata={ + "help": "if `global_batch_size` and `per_device_train_batch_size` is provied, " + "`gradient_accumulation_steps` will be ignored" + }, + ) + + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + + sequence_parallel: Optional[int] = field( + default=0, + metadata={}, + ) + + virtual_pp_degree: Optional[int] = field( + default=1, + metadata={ + "help": "vpp", + }, + ) + + use_async_save: Optional[bool] = field( + default=False, + metadata={"help": "Whether to use async_save instead of paddle.save."}, + ) + pre_alloc_memory: float = field( + default=0.0, + metadata={ + "help": "Pre-allocate one specific-capacity empty tensor " + "and release it for avoiding memory fragmentation" + }, + ) + + moe_group: Optional[str] = field( + default="dp", + metadata={ + "help": "The communication group of moe currently supports `dp|sharding|mp|dummy`" + }, + ) + use_moe: Optional[bool] = field( + default=False, metadata={"help": "Temporary alternative to expert parallelism."} + ) + moe_use_all2all: Optional[bool] = field( + default=False, + metadata={"help": "Whether to use the all2all communication method."}, + ) + log_global_grad_norm: Optional[bool] = field( + default=False, + metadata={ + "help": "Print the global gradient norm, which only takes effect when `enable_global_training_logs` is enabled.." + }, + ) + + multi_token_pred_depth: Optional[int] = field( + default=0, + metadata={}, + ) + + lr_scheduler: str = field( + default="cosine", + metadata={ + "help": "The scheduler type to use. suppor linear, cosine, constant, constant_with_warmup" + }, + ) + + moe_gate_lr_ratio: float = field( + default=None, + metadata={ + "help": ( + "When enabling MoE, apply special handling to the learning rate (LR) of the gate/router." + ) + }, + ) + vit_lr_ratio: float = field( + default=None, + metadata={ + "help": ( + "When enabling ViT training, apply special handling to the learning rate (LR) of ViT." 
+ ) + }, + ) + + pipeline_schedule_mode: str = field( + default="1F1B", + metadata={"help": "The pipeline schedule mode, support 1F1B and VPP"}, + ) + virtual_pipeline_seg_method: str = field( + default="ErnieDecoderLayerAuto", + metadata={"help": "The seg method of spliting pp layer for virtual pipeline."}, + ) + + model_type: Optional[str] = field( + default="ernie", + metadata={"help": "Only support for ernie pre-training for now."}, + ) + n_microbatches: int = field( + default=1, + metadata={"help": "Control the num of microbatches in one pp step."}, + ) + + @property + def need_data(self): + return self.pipeline_parallel_rank == 0 and self.tensor_parallel_rank == 0 + + @property + def reeao_dataset_world_size(self): + return super().dataset_world_size + + def __post_init__(self): + super().__post_init__() + + assert ( + self.global_batch_size + == self.per_device_train_batch_size + * self.gradient_accumulation_steps + * max(self.sharding_parallel_degree, 1) + * max(self.data_parallel_degree, 1) + ), ( + f"`gbs` should be equal to `lbs * acc * (dp_degree or sd_degree)`, " + f"but got gbs={self.global_batch_size}, " + f"lbs={self.per_device_train_batch_size}, " + f"acc={self.gradient_accumulation_steps}, " + f"dp_degree={max(self.data_parallel_degree, 1)}, " + f"sd_degree={max(self.sharding_parallel_degree, 1)}" + ) + + if self.global_batch_size > 0: + micro_bsz, acc_steps = reset_per_device_batch_size( + self.global_batch_size, + self.per_device_train_batch_size, + self.dataset_world_size, + ) + logger.info( + f"global_batch={self.global_batch_size} micro-bsz:{micro_bsz}, accumulate_steps:{acc_steps}" + ) + if ( + acc_steps != 1 + and self.gradient_accumulation_steps != 1 + and acc_steps != self.gradient_accumulation_steps + ): + raise ValueError( + f"global_accumulation_steps={self.gradient_accumulation_steps}" + f"& global_batch={self.global_batch_size} are both set" + ) + self.per_device_train_batch_size, self.gradient_accumulation_steps = ( + micro_bsz, + acc_steps, + ) + + self.max_gradient_accumulation_steps = self.gradient_accumulation_steps + + if self.pipeline_parallel_degree > 1: + self.per_device_eval_batch_size = ( + self.per_device_train_batch_size * self.gradient_accumulation_steps + ) + logger.warn( + f"eval_batch_size set to {self.per_device_eval_batch_size} in Pipeline Parallel!" 
+ ) + user_defined_strategy = fleet.fleet._user_defined_strategy + user_defined_strategy.strategy.pipeline_configs.accumulate_steps = ( + self.gradient_accumulation_steps + ) + + self.max_gradient_accumulation_steps = self.gradient_accumulation_steps + logger.info(f"fixing pp configs: {user_defined_strategy.pipeline_configs}") + else: + self.per_device_eval_batch_size = self.per_device_train_batch_size + logger.warn(f"eval_batch_size set to {self.per_device_eval_batch_size}") + + if self.sharding_parallel_degree > 1: + sharding_parallel_config = ( + set(self.sharding_parallel_config.split(" ")) + if self.sharding_parallel_config + else set() + ) + sharding_comm_overlap_non_pp = ( + True + if "shardingv1_comm_overlap" in sharding_parallel_config + or "sharding_comm_overlap" in sharding_parallel_config + else False + ) + if sharding_comm_overlap_non_pp: + assert hasattr(fleet.fleet, "_user_defined_strategy") + user_defined_strategy = fleet.fleet._user_defined_strategy + user_defined_strategy.hybrid_configs[ + "sharding_configs" + ].accumulate_steps = self.gradient_accumulation_steps + + if hasattr(fleet.fleet, "_user_defined_strategy"): + user_defined_strategy = fleet.fleet._user_defined_strategy + if ( + hasattr(user_defined_strategy, "hybrid_configs") + and "sharding_configs" in user_defined_strategy.hybrid_configs + ): + sd_configs = user_defined_strategy.hybrid_configs["sharding_configs"] + if sd_configs.comm_overlap: + assert self.global_batch_size % self.dataset_world_size == 0, ( + f"global_batch_size[{self.global_batch_size}] should be divisible by " + f"dataset_world_size[{self.dataset_world_size}]" + ) + lbs = self.global_batch_size // self.dataset_world_size + assert lbs % self.per_device_train_batch_size == 0, ( + f"local_batch_size[{lbs}] should be divisible by " + f"per_device_train_batch_size[{self.per_device_train_batch_size}]" + ) + assert ( + lbs // self.per_device_train_batch_size + == sd_configs.accumulate_steps + ), ( + f"local_batch_size[{lbs}] should be equal to " + f"accumulate_steps[{sd_configs.accumulate_steps}] * " + f"per_device_train_batch_size[{self.per_device_train_batch_size}]" + ) + + +class AutoPretrainingTrainer(AutoTrainer): + + def __init__(self, args=None, model=None, callbacks=[], **kwargs): + callbacks = [ + LoggingCallback(), + StopperCallback(), + TensorBoardCallback( + args, model=model, log_tokens_per_step=True, log_flops_per_step=False + ), + ] + callbacks + + args.use_async_save = ( + args.use_async_save and args.save_sharded_model and args.load_sharded_model + ) + super().__init__(args=args, model=model, callbacks=callbacks, **kwargs) + + def get_numel_item(p): + item = p.numel().item() + return item if item else 0 + + model_numel = sum( + get_numel_item(p) + for n, p in model.named_parameters() + if not p.stop_gradient and "embeddings" not in n and "embed_tokens" not in n + ) + numel_tensor = paddle.to_tensor(model_numel) + dist.all_reduce(numel_tensor) + self.model_numel = numel_tensor.item() // self.args.dataset_world_size + + self.pop_callback(PrinterCallback) + + def autocast_smart_context_manager(self): + + if self.enable_autocast_context_manager: + black = [ + "reduce_sum", + "c_softmax_with_cross_entropy", + "elementwise_div", + "sin", + "cos", + ] + white = [ + "lookup_table", + "lookup_table_v2", + "flash_attn", + "flash_attn_v1", + "matmul", + "matmul_v2", + "fused_gemm_epilogue", + ] + if self.args.bf16 and self.args.fp16_opt_level == "O2": + black.append("c_embedding") + + ctx_manager = autocast( + True, + custom_black_list=black, + 
custom_white_list=white, + level=self.args.fp16_opt_level, + dtype=self.amp_dtype, + ) + else: + ctx_manager = ( + contextlib.nullcontext() + if sys.version_info >= (3, 7) + else contextlib.suppress() + ) + + return ctx_manager + + def evaluate( + self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval" + ): + + self.model_wrapped.accumulate_steps = self.args.gradient_accumulation_steps + eval_dataloader = self.get_eval_dataloader(eval_dataset) + + start_time = time.time() + compute_metrics = self.compute_metrics + eval_loop = self.evaluation_loop + + output = eval_loop( + eval_dataloader, + description="Evaluation", + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + max_eval_iters=self.args.eval_iters, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate( + self.args, self.state, self.control, output.metrics + ) + return output.metrics + + def prediction_pipeline_step( + self, model, inputs, prediction_loss_only, ignore_keys + ): + + loss, _, labels = super().prediction_pipeline_step( + model, inputs, prediction_loss_only, ignore_keys + ) + num_tokens = (labels != self.tokenizer.ignored_index).sum().item() + loss_avg = loss * self.model_wrapped.accumulate_steps / num_tokens + return loss_avg, loss, labels + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + return PaddleNLPDistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + def get_train_dataloader(self): + + if self.args.need_data and self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + _DataLoader = partial( + DistDataLoaderAuto, + need_data=self.args.need_data, + pp_broadcast=True, + ) + + train_dataset = self.train_dataset + if self._is_iterable_dataset(train_dataset): + return DataLoader( + train_dataset, + batch_size=None, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + use_shared_memory=True, + prefetch_factor=self.args.prefetch_factor, + ) + if self.args.need_data: + train_sampler = self._get_train_sampler() + else: + train_sampler = None + return _DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + prefetch_factor=self.args.prefetch_factor, + ) + + def _maybe_log_save_evaluate( + self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs + ): + super()._maybe_log_save_evaluate( + tr_loss, model, epoch, ignore_keys_for_eval, **kwargs + ) + return + + def create_scheduler(self, num_training_steps): + + if self.args.warmup_steps > 0: + warmup = self.args.warmup_steps + else: + warmup = int(self.args.warmup_ratio * num_training_steps) + self.lr_scheduler = get_cosine_schedule_with_warmup( + self.args.learning_rate, + warmup, + self.args.max_steps, + min_lr=self.args.min_lr if self.args.min_lr else 0.0, + ) + + return self.lr_scheduler + + def create_optimizer(self, lr_scheduler=None): + """ + Setup the optimizer. + + We provide a reasonable default that works well. 
If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + optimizer_params = self.model.parameters() + if self.optimizer is None: + + def need_decay(name): + if ( + name == "ernie.norm.weight" + and self.args.pipeline_parallel_degree > 1 + ): + return True + return not any(nd in name for nd in ["bias", "norm"]) + + decay_parameters = [ + p.name for n, p in self.model.named_parameters() if need_decay(n) + ] + + def apply_decay_param_fun(x): + return x in decay_parameters + + optimizer_cls, optimizer_kwargs = AutoTrainer.get_optimizer_cls_and_kwargs( + self.args + ) + + if ( + self.args.use_moe + and not self.args.use_hybrid_parallel + and not self.args.enable_auto_parallel + ): + logger.info("using moe Global clip") + + def expert_fn(p): + return getattr(p, "no_sync", False) + + grad_clip = ClipGradForMOEByGlobalNormAuto( + self.args.max_grad_norm, + is_expert_param_func=expert_fn, + moe_group=_get_global_group(), + local_clip=False, + ) + else: + grad_clip = ( + nn.ClipGradByGlobalNorm(self.args.max_grad_norm) + if self.args.max_grad_norm > 0 + else None + ) + + self.static_name_to_dyg_name = { + p.name: n for n, p in self.model.state_dict().items() + } + gate_pattern = re.compile(r"ernie\.layers\.0\.mlp\.gate\.weight") + + def lr_ratio_fn(param): + if param.name in self.static_name_to_dyg_name.keys(): + name = self.static_name_to_dyg_name[param.name] + if self.args.moe_gate_lr_ratio is not None and gate_pattern.match( + name + ): + logger.info( + f"apply moe_gate_lr_ratio to {name}, ratio={self.args.moe_gate_lr_ratio}" + ) + return float(self.args.moe_gate_lr_ratio) + + return 1.0 + + self.optimizer = optimizer_cls( + learning_rate=( + self.lr_scheduler if lr_scheduler is None else lr_scheduler + ), + apply_decay_param_fun=apply_decay_param_fun, + parameters=optimizer_params, + weight_decay=self.args.weight_decay, + grad_clip=grad_clip, + multi_precision=True, + lr_ratio=( + lr_ratio_fn if self.args.moe_gate_lr_ratio is not None else None + ), + **optimizer_kwargs, + ) + + self.static_name_to_dyg_name = { + p.name: n for n, p in self.model.named_parameters() + } + + return self.optimizer + + def save_model(self, output_dir=None): + + super().save_model(output_dir) + if self.args.should_save: + with open( + os.path.join(output_dir, "static_name_to_dyg_name.json"), "w" + ) as of: + of.write(json.dumps(self.static_name_to_dyg_name)) + + def _get_meshes_for_loader(self): + def _get_mesh(pp_idx=0): + return self.global_mesh.get_mesh_with_dim("pp")[pp_idx] + + meshes = [] + if self.args.pipeline_parallel_degree > 1: + # input_ids + meshes.append( + [ + _get_mesh(0), + _get_mesh(-1), + ] + ) + # labels + meshes.append(_get_mesh(self.args.pipeline_parallel_degree - 1)) + else: + meshes.append(_get_mesh(0)) + return meshes + + def _wrap_for_dist_loader(self, train_dataloader): + self.dense_tensor_idx = None + dist_loader = dist.shard_dataloader( + dataloader=train_dataloader, + meshes=self._get_meshes_for_loader(), + shard_dims="dp", + is_dataset_splitted=True, + ) + dist_loader._input_keys = ["input_ids", "labels"] + return dist_loader diff --git a/examples/pre-training/ernie/src/utils_auto/__init__.py b/examples/pre-training/ernie/src/utils_auto/__init__.py new file mode 100644 index 00000000..0eb015e8 --- /dev/null +++ b/examples/pre-training/ernie/src/utils_auto/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .logging import logger, setup_logger_output_file + +__all__ = ["logger", "setup_logger_output_file"] diff --git a/examples/pre-training/ernie/src/utils_auto/logging.py b/examples/pre-training/ernie/src/utils_auto/logging.py new file mode 100644 index 00000000..e43daf69 --- /dev/null +++ b/examples/pre-training/ernie/src/utils_auto/logging.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +from pathlib import Path + +from paddleformers.utils.log import logger as paddlenlp_logger + +hdl = logging.StreamHandler(sys.stderr) +formatter = logging.Formatter( + fmt="[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]: %(message)s" +) +hdl.setFormatter(formatter) +logger = logging.getLogger() +logger.handlers = [hdl] + +bce_log = logging.getLogger("baidubce") +bce_log.handlers = [] +bce_log.propagate = False +logger.setLevel(10) + +bce_bns_proxy_log = logging.getLogger("bce_bns_proxy.wrapper") +bce_bns_proxy_log.disabled = True +filelock_log = logging.getLogger("filelock") +filelock_log.disabled = True + +paddlenlp_logger.logger.handlers = [] +paddlenlp_logger.logger.propagate = True + + +def setup_logger_output_file(outputpath, local_rank): + logdir = Path(outputpath) / "log" + logdir.mkdir(exist_ok=True) + file_hdl = logging.FileHandler( + logdir / f"workerlog.{local_rank}", mode="a", encoding="utf-8" + ) + formatter = logging.Formatter( + fmt=f"[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d][rank-{local_rank}]: %(message)s" + ) + file_hdl.setFormatter(formatter) + hdl.setFormatter(formatter) + logger.handlers = [hdl, file_hdl] diff --git a/examples/pre-training/ernie/src/utils_auto/misc.py b/examples/pre-training/ernie/src/utils_auto/misc.py new file mode 100644 index 00000000..418bc180 --- /dev/null +++ b/examples/pre-training/ernie/src/utils_auto/misc.py @@ -0,0 +1,214 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import re + +import numpy as np +import paddle +import paddle.distributed as dist + +logger = logging.getLogger(__name__) + +try: + from models.sequence_parallel_utils import get_async_loader + from paddle.incubate.tensor.manipulation import async_offload +except ImportError: + get_async_loader = async_offload = None + +__all__ = ("global_training_logs",) + +ZERO = paddle.zeros([], dtype="float32") + + +class SmoothedValue: + def __init__( + self, + skip_zero, + ): + self.total = 0.0 + self.count = 0 + self._skip_zero = skip_zero + + @paddle.no_grad() + def update(self, value): + if isinstance(value, paddle.Tensor): + value = value.astype("float32").detach() + if value.shape == [1]: + value = value.squeeze() + self.count += (value != ZERO).astype("int64") if self._skip_zero else 1 + else: + self.count += 1 + self.total += value + + @property + def global_avg(self): + return self.total / max(self.count, 1e-6) + + def reset(self): + self.total = 0.0 + self.count = 0 + + +class TrainingLogs: + _instance = None + + def __new__(cls, *args, **kw): + if cls._instance is None: + cls._instance = object.__new__(cls, *args, **kw) + return cls._instance + + def __init__(self): + self.meters = {} + self.snapshot = None + self._global_meters_keys = [] + self.trainer = None + self.logging_interval = None + self._skip_zero_keys = [] + + def set_trainer_interval(self, trainer, logging_interval): + self.trainer = trainer + self.logging_interval = logging_interval + + @property + def global_meters_keys(self): + return self._global_meters_keys + + @global_meters_keys.setter + def global_meters_keys(self, lst): + self._global_meters_keys = lst + + def enable_skip_zero(self, keys=[]): + logger.info("global_training_logs: use skip zero") + self._skip_zero_keys = keys + for m in self.meters.keys(): + for k in keys: + if re.match(k, m): + m._skip_zero = True + + def update(self, **kwargs): + for k, v in kwargs.items(): + self[k] = v + + def is_enabled(self): + return ( + self.trainer is None + or (self.trainer.state.global_step + 1) % self.logging_interval == 0 + ) + + def __setitem__(self, k, v): + skip_zero = False + for skip_k in self._skip_zero_keys: + if re.match(skip_k, k): + skip_zero = True + metric = self.meters.setdefault(k, SmoothedValue(skip_zero=skip_zero)) + metric.update(v) + + def __getitem__(self, v): + return self.meters[v] + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'" + ) + + def dict(self, use_async=False): + avg_metric = { + k: v.global_avg + for k, v in self.meters.items() + if k not in self.global_meters_keys + } + + if self.global_meters_keys: + tensor_lst = [] + for k in self.global_meters_keys: + v = self.meters[k].global_avg if k in self.meters else -100 + tensor_lst.append(paddle.to_tensor(v, "float32")) + gathered_v = [] + dist.gather(paddle.stack(tensor_lst), gathered_v, 0) + if gathered_v: + for i, k in enumerate(self.global_meters_keys): + avg_metric[k] = np.mean( + [t[i] for t in gathered_v if t[i] != -100] + ).item() + + if not use_async: + ret = { + k: v.item() if isinstance(v, paddle.Tensor) else v + for k, v in avg_metric.items() + } + global_info = {k: v for k, v in ret.items() if k in self.global_meters_keys} + ret = { + k: v + for k, v in ret.items() + if (k 
not in self.global_meters_keys) + and ((not self.meters[k]._skip_zero) or v != 0.0) + } + return ret, global_info + assert get_async_loader is not None, "async logging requires latest paddle" + if not avg_metric: + return lambda: ({}, {}) + keys, values = zip(*avg_metric.items()) + tensor_list = [ + (i, t) for i, t in enumerate(values) if isinstance(t, paddle.Tensor) + ] + if tensor_list: + async_loader = get_async_loader() + tensor_id, tensor_list = zip(*tensor_list) + tensor_list = paddle.stack(tensor_list) + tensor_list_cpu, task = async_offload(tensor_list, async_loader) + else: + task = None + + def _ret(): + nonlocal task, tensor_list_cpu, values + values = list(values) + if task: + task.cpu_wait() + for i, val in zip(tensor_id, tensor_list_cpu.tolist()): + values[i] = val + ret = dict(zip(keys, values)) + global_info = {k: v for k, v in ret.items() if k in self.global_meters_keys} + ret = { + k: v + for k, v in ret.items() + if (k not in self.global_meters_keys) + and ((not self.meters[k]._skip_zero) or v != 0.0) + } + return ret, global_info + + return _ret + + def reset(self): + for k in list(self.meters.keys()): + self.meters[k].reset() + self.meters.pop(k) + + def take_snapshot(self): + self.snapshot = copy.deepcopy(self.meters) + + def restore_snapshot(self): + assert ( + self.snapshot is not None + ), "you should use take_snapshot before restore_snapshot" + self.meters = copy.deepcopy(self.snapshot) + self.snapshot = None + + +global_training_logs = TrainingLogs() diff --git a/examples/pre-training/ernie/src/utils_auto/training_utils.py b/examples/pre-training/ernie/src/utils_auto/training_utils.py new file mode 100644 index 00000000..027606fc --- /dev/null +++ b/examples/pre-training/ernie/src/utils_auto/training_utils.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
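+# Helper for deriving the per-device train batch size and the gradient
+# accumulation steps from the global batch size and the data-parallel
+# (dataset) world size.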
+ +import logging + +logger = logging.getLogger(__name__) + + +def reset_per_device_batch_size( + global_batch_size, per_device_train_batch_size, dataset_world_size +): + assert ( + global_batch_size % dataset_world_size == 0 + ), f"global_bsz={global_batch_size} not evenly divided by world_size={dataset_world_size}" + batch_per_device = global_batch_size // dataset_world_size + if batch_per_device < per_device_train_batch_size: + gradient_accumulation_steps = 1 + per_device_train_batch_size = batch_per_device + logger.info( + f"reset `per_device_train_batch_size` to {per_device_train_batch_size}, global_batch_size={global_batch_size }, " + f"dp_worldsize={ dataset_world_size}, accumulate_steps={gradient_accumulation_steps} " + ) + else: + assert ( + batch_per_device % per_device_train_batch_size == 0 + ), f"global_bsz={global_batch_size} not evenly divided by world_size={dataset_world_size}, batch_per_device={batch_per_device}" + gradient_accumulation_steps = batch_per_device // per_device_train_batch_size + logger.info( + f"per_device_train_batch_size={per_device_train_batch_size}, global_batch_size={global_batch_size }, " + f"dp_worldsize={dataset_world_size}, accumulate_steps={gradient_accumulation_steps} " + ) + return per_device_train_batch_size, gradient_accumulation_steps diff --git a/examples/pre-training/model_configs_auto/model_config.json b/examples/pre-training/model_configs_auto/model_config.json new file mode 100644 index 00000000..f552d11f --- /dev/null +++ b/examples/pre-training/model_configs_auto/model_config.json @@ -0,0 +1,66 @@ +{ + "architectures": [ + "ErnieForCausalLM" + ], + "bos_token_id": 0, + "eos_token_id": 1, + "hidden_act": "silu", + "hidden_size": 8192, + "intermediate_size": 28672, + "initializer_range": 0.00482174, + "max_sequence_length": 4096, + "max_position_embeddings": 4096, + "model_type": "ernie_pp", + "num_attention_heads": 64, + "num_key_value_heads": 8, + "num_hidden_layers": 4, + "pad_token_id": -1, + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "transformers_version": "4.27.0.dev0", + "use_cache": true, + "vocab_size": 100352, + "rope_theta": 10000, + "use_recompute": false, + "use_recompute_attn": false, + "use_recompute_moe": false, + "use_recompute_loss_fn": false, + "use_rmsnorm": true, + "fuse_rms_norm": true, + "use_bias": false, + "use_fast_ln": true, + "fuse_attn_ffn": true, + "fuse_linear": true, + "rope_reorder": false, + "fuse_rope": true, + "fuse_swiglu": true, + "fuse_gate_detach_matmul": true, + "remove_tail_layer": 2, + "refined_recompute": { + "mlp_row_ln": -1, + "flash_attn": -1, + "attention_row_ln": -1, + "attention_column_ln": 2, + "mlp_column_ln": 0 + }, + "moe_num_experts": 16, + "moe_num_shared_experts": 0, + "moe_layer_start_index": 2, + "moe_group_experts": false, + "moe_intermediate_size": 3584, + "moe_capacity": [8,8,8], + "moe_gate": "top2_fused", + "moe_gate_scale": false, + "moe_gate_detach": 1.0, + "moe_k": 8, + "moe_aux_loss_lambda": 1e-5, + "moe_group_orthogonal_loss": true, + "moe_orthogonal_loss_lambda": 0.0, + "moe_z_loss_lambda": 0.0, + "moe_layer_interval": 1, + "z_loss_lambda": 0, + "using_precision_check": false, + "use_ep_comm_overlap": true, + "moe_use_all2all": true, + "tie_word_embeddings": true +} diff --git a/examples/pre-training/models/ernie/__init__.py b/examples/pre-training/models/ernie/__init__.py index b00b0579..97095731 100644 --- a/examples/pre-training/models/ernie/__init__.py +++ b/examples/pre-training/models/ernie/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific 
language governing permissions and # limitations under the License. -from .configuration import ErnieMoEConfig -from .modeling_pp import ErnieMoEForCausalLMPipe -__all__ = ['ErnieMoEConfig', 'ErnieMoEForCausalLMPipe'] +from .configuration import * # noqa +from .modeling import * # noqa +from .modeling_auto import * # noqa +from .modeling_auto_pp import * # noqa diff --git a/examples/pre-training/models/ernie/configuration_auto.py b/examples/pre-training/models/ernie/configuration_auto.py new file mode 100644 index 00000000..fc858468 --- /dev/null +++ b/examples/pre-training/models/ernie/configuration_auto.py @@ -0,0 +1,728 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Ernie model configuration""" +import logging +import json +from typing import Union +import paddle.distributed.communication.group + +from paddleformers.transformers.configuration_utils import PretrainedConfig + +logger = logging.getLogger(__name__) + +__all__ = [ + "ERNIE_PRETRAINED_INIT_CONFIGURATION", + "ErnieMoEConfig", +] + +ERNIE_PRETRAINED_INIT_CONFIGURATION = { + "ernie/tiny-random-ernie": { + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "ernie", + "num_attention_heads": 2, + "num_hidden_layers": 2, + "rms_norm_eps": 1e-06, + "vocab_size": 32000, + "bos_token_id": 1, + "eos_token_id": 2, + "pad_token_id": 0, + "use_cache": False, + "use_recompute": False, + "use_flash_attn": True, + "use_pure_fp16": False, + }, +} + + +class ErnieConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~ErnieModel`]. It is used to instantiate an Ernie + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Ernie-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Ernie model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ErnieModel`] or [`~TFErnieModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + ```python + >>> from paddleformers.transformer import ErnieModel, ErnieConfig + + >>> # Initializing a Ernie ernie-7b style configuration + >>> configuration = ErnieConfig() + + >>> # Initializing a model from the ernie-7b style configuration + >>> model = ErnieModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ernie" + attribute_map = { + "n_positions": "max_position_embeddings", + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "n_inner": "intermediate_size", + "activation_function": "hidden_act", + } + pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=11008, + max_position_embeddings=32768, + num_hidden_layers=2, + num_attention_heads=2, + head_dim=None, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, + use_flash_attn=True, + use_mem_eff_attn=False, + use_flash_attn_with_mask=False, + use_recompute=False, + use_recompute_attn=False, + recompute_use_reentrant=False, + use_rmsnorm=True, + z_loss_lambda=None, + fuse_rms_norm=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + fuse_attn_ffn=False, + fuse_swiglu=False, + use_bias=False, + expert_mlp_use_bias=None, + rope_reorder=True, + rope_theta=10000, + fuse_rope=False, + use_fast_ln=False, + weight_share_add_bias=True, + fuse_linear=False, + seqlen=False, + ignored_index=-100, + remove_tail_layer=False, + use_recompute_lm_head=False, + use_recompute_loss_fn=False, + use_recompute_mtp=False, + use_recompute_dnd=False, + selective_no_recompute_num=0, + use_mp_gathered_weight=False, + refined_recompute=dict(), + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + compression_ratio: float = 1.0, + quant_bits=-1, + num_key_value_heads=None, + submatrix_parallel=False, + submatrix_parallel_low_memory=True, + use_sparse_head_and_loss_fn=False, + using_dynamic_sequence_length=False, + micro_batch_size=-1, + using_precision_check=False, + use_qk_norm=False, + use_tpsp_comm_overlap=False, + offload_pp_data_chunk_size=0, + use_fused_head_loss_fn=False, + use_recompute_resampler=False, + resampler_fuse_rms_norm=False, + token_loss_equal_weight=False, + token_balance_loss=False, + token_balance_seqlen=False, + use_fp8=False, + fp8_configs=dict(), + use_fp8_mlp=False, + fp8_mem_configs=dict(), + fp8_fused_ops_configs=dict(), + drop_before_deepep=False, + deepep_drop_padding=False, + disable_pipeline_warmup=False, + skip_align_position_id=False, + rope_3d=False, + freq_allocation=0, + moe_layer_feed_fake_token=False, + decoderlayer_act_offload_settings={"type": "", "value": ""}, + loss_subbatch_seqlen=32768, + gate_force_zero_padding_grad=False, + recompute_num_layers=None, + use_combine_before_a2a=False, + use_quant_before_a2a=False, + rope_yarn_config={}, + **kwargs, + ): + if "tie_word_embeddings" not in kwargs: + 
kwargs["tie_word_embeddings"] = False + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_recompute_attn = use_recompute_attn + if use_recompute_attn: + logger.warning("set `use_recompute_attn`=True, disabling `use_recompute`") + use_recompute = False + self.use_recompute = use_recompute + self.recompute_num_layers = ( + recompute_num_layers + if recompute_num_layers is not None + else num_hidden_layers + ) + self.use_flash_attn = use_flash_attn + self.recompute_use_reentrant = recompute_use_reentrant + self.use_mem_eff_attn = use_mem_eff_attn + self.use_flash_attn_with_mask = use_flash_attn_with_mask + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.fuse_attn_ffn = fuse_attn_ffn + self.fuse_swiglu = fuse_swiglu + self.fuse_rms_norm = fuse_rms_norm + self.use_rmsnorm = use_rmsnorm + self.z_loss_lambda = z_loss_lambda + self.using_dynamic_sequence_length = using_dynamic_sequence_length + if using_dynamic_sequence_length: + assert ( + micro_batch_size > 0 + ), "micro_batch_size should be set when using_dynamic_sequence_length" + self.micro_batch_size = micro_batch_size + self.using_precision_check = using_precision_check + self.use_qk_norm = use_qk_norm + + self.seqlen = seqlen + self.use_bias = use_bias + self.weight_share_add_bias = weight_share_add_bias + self.rope_reorder = rope_reorder + self.rope_yarn_config = rope_yarn_config + self.rope_theta = rope_theta + self.fuse_rope = fuse_rope + self.use_fast_ln = use_fast_ln + + self.fuse_linear = fuse_linear + self.ignored_index = ignored_index + self.remove_tail_layer = remove_tail_layer + self.use_recompute_lm_head = use_recompute_lm_head + self.use_recompute_loss_fn = use_recompute_loss_fn + self.use_recompute_mtp = use_recompute_mtp + self.use_recompute_dnd = use_recompute_dnd + + self.use_mp_gathered_weight = use_mp_gathered_weight + self.selective_no_recompute_num = selective_no_recompute_num + + self.refined_recompute = refined_recompute + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.hidden_dropout_prob = hidden_dropout_prob + self.compression_ratio = compression_ratio + self.skip_recompute_ops = dict() + self.quant_bits = quant_bits + self.num_key_value_heads = num_key_value_heads + self.submatrix_parallel = submatrix_parallel + self.submatrix_parallel_low_memory = submatrix_parallel_low_memory + self.use_sparse_head_and_loss_fn = use_sparse_head_and_loss_fn + self.use_tpsp_comm_overlap = use_tpsp_comm_overlap + self.offload_pp_data_chunk_size = offload_pp_data_chunk_size + self.use_fused_head_loss_fn = use_fused_head_loss_fn + self.use_recompute_resampler = use_recompute_resampler + self.resampler_fuse_rms_norm = resampler_fuse_rms_norm + self.token_balance_loss = token_balance_loss + self.token_balance_seqlen = token_balance_seqlen + self.rope_3d = rope_3d + self.freq_allocation = freq_allocation + self.decoderlayer_act_offload_settings = decoderlayer_act_offload_settings + self.loss_subbatch_seqlen = loss_subbatch_seqlen + self.gate_force_zero_padding_grad = 
gate_force_zero_padding_grad + + default_fp8_configs = { + "quant_scheme": "DelayedScaling", + "recipe": { + "format": "hybrid", + "calibrating": True, + "amax_history_len": 1024, + "amax_compute_algo": "max", + "fuse_wgrad_accumulation": False, + "quant_weight_at_first_microbatch": False, + }, + "layers": { + "attn_fc1_linear": True, + "attn_fc2_linear": True, + "mlp_fc1_linear": True, + "mlp_fc2_linear": True, + "attn_tp_fc1_linear": True, + "attn_tp_fc2_linear": True, + "mlp_tp_fc1_linear": True, + "mlp_tp_fc2_linear": True, + }, + "smooth_swiglu": False, + } + + def update_nested_dict(default_dict, update_dict): + for key, value in update_dict.items(): + if ( + isinstance(value, dict) + and key in default_dict + and isinstance(default_dict[key], dict) + ): + update_nested_dict(default_dict[key], value) + else: + default_dict[key] = value + + update_nested_dict(default_fp8_configs, fp8_configs) + self.fp8_configs = default_fp8_configs + self.use_fp8 = use_fp8 + self.expert_mlp_use_bias = expert_mlp_use_bias + self.use_fp8_mlp = use_fp8_mlp + default_fp8_mem_configs = { + "shared_expert": False, + "recompute_fwd_gate_up": False, + "dequant_input": False, + } + update_nested_dict(default_fp8_mem_configs, fp8_mem_configs) + self.fp8_mem_configs = default_fp8_mem_configs + default_fp8_fused_ops_configs = { + "stack_quant": False, + "swiglu_probs_bwd": False, + "split_group_gemm": True, + } + update_nested_dict(default_fp8_fused_ops_configs, fp8_fused_ops_configs) + self.fp8_fused_ops_configs = default_fp8_fused_ops_configs + self.drop_before_deepep = drop_before_deepep + self.deepep_drop_padding = deepep_drop_padding + self.disable_pipeline_warmup = disable_pipeline_warmup + self.skip_align_position_id = skip_align_position_id + self.moe_layer_feed_fake_token = moe_layer_feed_fake_token + + if self.sequence_parallel: + assert ( + self.using_dynamic_sequence_length or self.seqlen + ), "seqlen not provided in sequence-parallel when not using dygramic sequence length" + + assert ( + self.tensor_parallel_degree > 1 + ), f"senquence-parallel only works in mp, got mp={self.tensor_parallel_degree}" + + self.register_nonsaveable_keys("use_recompute") + self.register_nonsaveable_keys("recompute_use_reentrant") + self.register_nonsaveable_keys("refined_recompute") + self.register_nonsaveable_keys("use_recompute_attn") + self.register_nonsaveable_keys("use_recompute_lm_head") + self.register_nonsaveable_keys("use_recompute_mtp") + self.register_nonsaveable_keys("use_recompute_dnd") + self.register_nonsaveable_keys("use_recompute_loss_fn") + self.register_nonsaveable_keys("using_precision_check") + self.register_nonsaveable_keys("skip_recompute_ops") + + def __setattr__(self, name: str, value): + super().__setattr__(name, value) + if getattr(self, "use_recompute", False): + assert not getattr( + self, "use_recompute_attn", False + ), "cannot set `use_recompute_attn=True` when `use_recompute=True`" + + def register_nonsaveable_keys(self, keys): + + if hasattr(super(), "register_nonsaveable_keys"): + return super().register_nonsaveable_keys(keys) + elif hasattr(super(), "register_unsavable_keys"): + return super().register_unsavable_keys(keys) + else: + raise AttributeError( + "register_nonsaveable_keys not found in PretrainedConfig" + ) + + +class ErnieMoEConfig(ErnieConfig): + r""" + This is the configuration class to store the configuration of a [`~ErnieModel`]. It is used to instantiate an Ernie + model according to the specified arguments, defining the model architecture. 
Instantiating a configuration with the + defaults will yield a similar configuration to that of the Ernie-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Ernie model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ErnieModel`] or [`~TFErnieModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + ```python + >>> from paddleformers.transformer import ErnieModel, ErnieConfig + + >>> # Initializing a Ernie ernie-7b style configuration + >>> configuration = ErnieConfig() + + >>> # Initializing a model from the ernie-7b style configuration + >>> model = ErnieModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ernie" + attribute_map = { + "n_positions": "max_position_embeddings", + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "n_inner": "intermediate_size", + "activation_function": "hidden_act", + } + pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION + + def __init__( + self, + moe_num_experts: Union[int, list] = 0, + use_fake_gate=False, + use_recompute_moe=False, + moe_capacity=(), + moe_layer_interval=2, + moe_layer_start_index: Union[int, list] = 0, + moe_layer_end_index: Union[int, list] = -1, + moe_aux_loss_lambda=1e-2, + moe_z_loss_lambda=1e-4, + moe_orthogonal_loss_lambda=1e-2, + moe_use_size_all2all=False, + sinkhorn_2gate=True, + sinkhorn_temp=3e-2, + global_aux_loss=False, + moe_dropout_prob=0.0, + moe_group="world", + moe_gate="top2", + moe_num_attn_experts=False, + moe_logging=False, + num_experts_per_tok: int = 8, + moe_intermediate_size: Union[int, list] = 0, + moe_num_shared_experts: int = 0, + moe_num_dense_experts: int = 0, + moe_dense_experts_token_type_id: int = 3, + moe_multimodal_dispatch_use_allgather: str = "", + moe_multimodal_paired_experts: bool = False, + moe_reverse_token_drop: bool = False, + moe_gate_act: str = "softmax", + moe_norm_gate_logits=True, + moe_use_hard_gate: bool = False, + moe_use_bpr: bool = False, + moe_fuse_experts: bool = False, + moe_all_to_all_dropout: float = 0.0, + moe_use_token_type_bias: bool = False, + moe_k=2, + 
moe_use_aux_free: bool = False, + moe_group_experts: bool = False, + moe_group_orthogonal_loss: bool = False, + moe_with_send_router_loss: bool = True, + enable_delay_scale_loss: bool = True, + num_acc_steps: int = None, + insert_empty_layer: list = None, + pp_no_recompute_layer: list = None, + multi_token_pred_depth: int = 0, + multi_token_pred_lambda: float = 0.3, + fuse_gate_detach_matmul: bool = False, + enable_mtp_magic_send: bool = False, + use_elastic_topk: bool = False, + use_deepep: bool = False, + use_elastic_expert_num: bool = False, + elastic_min_expert_num: int = 0, + all_expert_ratio: float = 1.0, + use_elastic_topk_for_mbs: bool = False, + elastic_min_topk: int = 1, + elastic_max_topk: int = None, + n_group: int = 0, + topk_group: int = 0, + scaling_factor: float = None, + aux_loss_type: str = "", + deepep_fine_grained: bool = False, + deepep_use_fused: bool = False, + deepep_tokens_per_subbatch: int = 0, + use_linear_residual_norm_recompute: bool = False, + use_rms_qkv_recompute: bool = False, + build_skip_comm_buffer: bool = False, + use_norm_gate_recompute: bool = False, + moe_state_dict_use_global_expert_id: bool = False, + enable_entropy_logging: bool = False, + use_fp8_fuse_node: bool = False, + use_combine_before_a2a: bool = False, + use_fp8_dispatch_a2a: bool = False, + use_ep_comm_overlap: bool = False, + **kwargs, + ): + """ + config + """ + if use_recompute_moe: + logger.warning("set `use_recompute_moe`=True, disabling `use_recompute`") + kwargs["use_recompute"] = False + super().__init__(**kwargs) + # moe + self.use_fake_gate = use_fake_gate + self.use_recompute_moe = use_recompute_moe + self.moe_num_experts = moe_num_experts + self.moe_capacity = moe_capacity + self.moe_aux_loss_lambda = moe_aux_loss_lambda + self.moe_z_loss_lambda = moe_z_loss_lambda + self.moe_orthogonal_loss_lambda = moe_orthogonal_loss_lambda + self.global_aux_loss = global_aux_loss + self.sinkhorn_2gate = sinkhorn_2gate + self.sinkhorn_temp = sinkhorn_temp + self.moe_layer_interval = moe_layer_interval + self.moe_dropout_prob = moe_dropout_prob + self.moe_group = moe_group + self.moe_gate = moe_gate + self.moe_num_attn_experts = moe_num_attn_experts + self.moe_use_size_all2all = moe_use_size_all2all + self.moe_logging = moe_logging + self.num_experts_per_tok = num_experts_per_tok + self.moe_num_shared_experts = moe_num_shared_experts + self.moe_num_dense_experts = moe_num_dense_experts + self.moe_dense_experts_token_type_id = moe_dense_experts_token_type_id + self.moe_intermediate_size = moe_intermediate_size + self.moe_reverse_token_drop = moe_reverse_token_drop + self.moe_use_hard_gate = moe_use_hard_gate + self.moe_fuse_experts = moe_fuse_experts + self.moe_k = moe_k + self.moe_all_to_all_dropout = moe_all_to_all_dropout + self.moe_use_token_type_bias = moe_use_token_type_bias + self.moe_use_bpr = moe_use_bpr + self.moe_group_experts = moe_group_experts + self.moe_group_orthogonal_loss = moe_group_orthogonal_loss + # optimize send without router loss + self.moe_with_send_router_loss = moe_with_send_router_loss + self.enable_delay_scale_loss = enable_delay_scale_loss + self.num_acc_steps = num_acc_steps + self.moe_layer_start_index = moe_layer_start_index + self.moe_layer_end_index = ( + self.num_hidden_layers - 1 + if moe_layer_end_index == -1 + else moe_layer_end_index + ) + self.moe_multimodal_dispatch_use_allgather = ( + moe_multimodal_dispatch_use_allgather + ) + self.moe_multimodal_paired_experts = moe_multimodal_paired_experts + self.moe_gate_act = moe_gate_act + 
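+        # Gate behaviour flags consumed by the MoE gate implementation
+        # (gate-logit activation, logit normalization, aux-free routing).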
self.moe_norm_gate_logits = moe_norm_gate_logits + self.moe_use_aux_free = moe_use_aux_free + self.fuse_gate_detach_matmul = fuse_gate_detach_matmul + if insert_empty_layer is not None: + assert isinstance( + insert_empty_layer, list + ), "insert_empty_layer should be a list" + else: + insert_empty_layer = [] + + # Overlap A2A communication with shared expert and auxiliary loss. + self.use_ep_comm_overlap = use_ep_comm_overlap + # Move the combine operation before A2A communication. + self.use_combine_before_a2a = use_combine_before_a2a + # Use FP8 for dispatch communication. + self.use_fp8_dispatch_a2a = use_fp8_dispatch_a2a + + # Multi-Token Prediction (MTP) + self.multi_token_pred_depth = multi_token_pred_depth + self.multi_token_pred_lambda = multi_token_pred_lambda + self.enable_mtp_magic_send = enable_mtp_magic_send + + self.insert_empty_layer = insert_empty_layer + + # elastic + self.use_elastic_topk = use_elastic_topk + self.use_elastic_expert_num = use_elastic_expert_num + self.elastic_min_expert_num = elastic_min_expert_num + self.all_expert_ratio = all_expert_ratio + self.use_elastic_topk_for_mbs = use_elastic_topk_for_mbs + self.elastic_min_topk = elastic_min_topk + if elastic_max_topk is None: + self.elastic_max_topk = self.moe_k * 2 - 1 + + # Using fusion expert node in moe layer. + self.use_fp8_fuse_node = use_fp8_fuse_node + + # Perform MoE computation at expert granularity. + self.deepep_fine_grained = deepep_fine_grained + # Requires deepep_fine_grained to be enabled; further disperses token + # granularity within experts to compute subbatches. + self.deepep_tokens_per_subbatch = deepep_tokens_per_subbatch + # Fuse combine and scatter operations when using BF16 for expert computation. + self.deepep_use_fused = deepep_use_fused + + assert not ( + self.use_combine_before_a2a and self.use_deepep + ), "combine_before_a2a is not supported for deepep now." + + assert not ( + self.use_fp8_dispatch_a2a and not self.use_fp8_fuse_node + ), "fp8_dispatch_a2a must be used with use_fp8_fuse_node." + + assert not ( + self.use_fp8_dispatch_a2a and self.use_ep_comm_overlap + ), "fp8_dispatch_a2a connot be used with use_ep_comm_overlap." + + if self.deepep_tokens_per_subbatch: + assert ( + self.deepep_fine_grained + ), "deepep_fine_grained must be enabled when deepep_tokens_per_subbatch is set." 
+ + # node limit routing + self.n_group = n_group + self.topk_group = topk_group + + # router scaling_factor + self.scaling_factor = scaling_factor + + self.build_skip_comm_buffer = build_skip_comm_buffer + + # router loss type + assert aux_loss_type in ["", "default", "seq_aux_loss", "switch_aux_loss"] + self.aux_loss_type = aux_loss_type + + self.use_deepep = use_deepep + if self.moe_multimodal_paired_experts and isinstance( + self.moe_num_experts, (tuple, list) + ): + logger.warning( + "moe_num_experts must be one element when using paired experts" + ) + self.moe_num_experts = self.moe_num_experts[0] + + if pp_no_recompute_layer is not None: + assert isinstance( + insert_empty_layer, list + ), "pp_no_recompute_layer should be a list" + + self.pp_no_recompute_layer = pp_no_recompute_layer + self.register_nonsaveable_keys("moe_group") + self.register_nonsaveable_keys("pp_no_recompute_layer") + + if ( + self.moe_group in ["dp", "data"] + and self.moe_multimodal_dispatch_use_allgather + ): + assert ( + self.moe_num_shared_experts == 0 + ), "shared experts are not supported when using dp moe and moe_allgather_layer" + assert ( + self.moe_num_dense_experts == 0 + ), "dense experts are not supported when using dp moe and moe_allgather_layer" + + self.use_linear_residual_norm_recompute = use_linear_residual_norm_recompute + self.use_rms_qkv_recompute = use_rms_qkv_recompute + self.use_norm_gate_recompute = use_norm_gate_recompute + self.moe_state_dict_use_global_expert_id = moe_state_dict_use_global_expert_id + self.enable_entropy_logging = enable_entropy_logging + + @property + def multimodel_experts(self) -> bool: + + return ( + isinstance(self.moe_num_experts, (tuple, list)) + and len(self.moe_num_experts) > 1 + ) + + @property + def use_moe(self) -> bool: + """_summary_ + + Returns: + bool: _description_ + """ + return ( + sum(self.moe_num_experts) > 0 + if self.multimodel_experts + else self.moe_num_experts > 0 + ) + + def __setattr__(self, name: str, value): + super().__setattr__(name, value) + if getattr(self, "use_recompute", False): + assert not getattr( + self, "use_recompute_moe", False + ), "cannot set `use_recompute_moe=True` when `use_recompute=True`" + + def to_json_string(self, use_diff: bool = True) -> str: + + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + + def _serializer(obj): + if isinstance(obj, paddle.distributed.communication.group.Group): + return repr(obj) + raise TypeError(f"Type {type(obj)} is not serializable") + + return ( + json.dumps( + config_dict, + indent=2, + sort_keys=True, + ensure_ascii=False, + default=_serializer, + ) + + "\n" + ) diff --git a/examples/pre-training/models/ernie/modeling_auto.py b/examples/pre-training/models/ernie/modeling_auto.py new file mode 100644 index 00000000..ef145aab --- /dev/null +++ b/examples/pre-training/models/ernie/modeling_auto.py @@ -0,0 +1,2851 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle Ernie model""" +import math +import functools +from functools import partial +import logging +from typing import Optional, Tuple +import contextlib +import inspect + + +from copy import deepcopy +from dataclasses import dataclass +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import recompute +from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker +from paddle.incubate.nn.memory_efficient_attention import ( + memory_efficient_attention, + BlockDiagonalCausalMask, +) +from paddle.distributed import in_auto_parallel_align_mode + + +from models.moe.top2_gate_auto import Top2Gate, TopKGateFusedAuto + + +from paddleformers.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) + +from paddleformers.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions as _BaseModelOutput, +) +from paddleformers.transformers.model_outputs import CausalLMOutputWithCrossAttentions + +from paddleformers.transformers.model_utils import PretrainedModel, register_base_model + +from models.sequence_parallel_utils_auto import ( + sequence_parallel_sparse_mask_labels, +) +from models.moe.moe_layer_auto import ( + MOELayerAuto, +) +from models.ernie.configuration_auto import ErnieMoEConfig +from models.moe.moe_utils_auto import get_mesh + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(_BaseModelOutput): + + router_loss: Optional[paddle.Tensor] = None + gate_logits: Optional[Tuple[paddle.Tensor]] = None + + +@dataclass +class CausalLMOutputWithCrossAttentionsAuto(CausalLMOutputWithCrossAttentions): + + router_loss: Optional[paddle.Tensor] = None + + +logger = logging.getLogger(__name__) + + +try: + from paddle.nn.functional.flash_attention import flash_attention + + logger.warning( + "Use flash attention in scaled-dot-product. Attention mask is deprecated" + ) +except (ImportError, ModuleNotFoundError): + flash_attention = None + +try: + from paddle.nn.functional.flash_attention import flash_attention_with_mask +except (ImportError, ModuleNotFoundError): + try: + from paddle.nn.functional.flash_attention import ( + scaled_dot_product_attention as flash_attention_with_mask, + ) + except (ImportError, ModuleNotFoundError): + logger.warning( + "flash_attention_with_mask not found. Use FleetY8.2 SFT instead." + ) + flash_attention_with_mask = None + +try: + from paddle.nn.functional.flash_attention import flash_attention_with_sparse_mask +except (ImportError, ModuleNotFoundError): + logger.warning("flash_attention_with_sparse_mask not found. Use FleetY8.9 instead.") + flash_attention_with_sparse_mask = None + +try: + from to_block_diag_causal_mask import to_block_diag_causal_mask +except (ImportError, ModuleNotFoundError): + logger.warning("to_block_diag_causal_mask not found. 
Use FleetY8.2 SFT instead.") + to_block_diag_causal_mask = None + +try: + from fast_ln import fast_ln +except ImportError: + fast_ln = None + + +try: + from paddle.incubate.nn.functional import ( + fused_rotary_position_embedding as fused_rope, + ) +except (ImportError, ModuleNotFoundError): + logger.warning("fused_rotary_position_embedding not found") + fused_rope = None + +try: + from paddle.incubate.nn.functional import swiglu as fused_swiglu +except (ImportError, ModuleNotFoundError): + fused_swiglu = None + + +ERNIE_PRETRAINED_MODEL_ARCHIVE_LIST = [] + +__all__ = [ + "ErnieModelAuto", + "ErniePretrainedModelAuto", + "ErnieForCausalLMAuto", +] + + +gate_class = dict( + top2=Top2Gate, + top2_fused=TopKGateFusedAuto, +) + + +def subbatch(f, arg_idx, axis, bs, out_idx, use_recompute=False, same_arg_idx={}): + @functools.wraps(f) + def wrapper(*args, **kwargs): + + assert len(arg_idx) == len( + axis + ), "Number of batching args and number of batching dims should match." + + inps = [args[i] for i in arg_idx] + axis_width = [inp.shape[d] for inp, d in zip(inps, axis)] + assert len(set(axis_width)) == 1, "Batch sizes should be kept equal." + + inp_axis = {inp: d for inp, d in zip(inps, axis)} + + axis_width = axis_width[0] + if axis_width < bs: + return f(*args, **kwargs) + + outs = [] + for slice_at in np.arange(0, axis_width, bs): + _args = [] + for i, inp in enumerate(args): + if i in same_arg_idx: + assert ( + i > same_arg_idx[i] + ), f"expect i > same_arg_idx[i], but got i: {i} and same_arg_idx[i]: {same_arg_idx[i]}" + _args.append(_args[same_arg_idx[i]]) + elif i in arg_idx: + inp = inp.slice( + [inp_axis[inp]], + [slice_at], + [min(inp.shape[inp_axis[inp]], slice_at + bs)], + ) + _args.append(inp) + else: + _args.append(inp) + if use_recompute: + out = paddle.distributed.fleet.utils.recompute(f, *_args, **kwargs) + else: + out = f(*_args, **kwargs) + outs.append(out) + + return paddle.concat(outs, out_idx) + + return wrapper + + +class FusedDropoutImpl(nn.Layer): + + def __init__(self, prob, mode): + super().__init__() + self.prob = prob + self.mode = mode + + self.dropout = nn.Dropout(p=prob, mode=mode) + + def forward(self, x, y): + if self.prob > 0: + x = self.dropout(x) + output = x + y + + return output + + +def is_pp_enable(): + + mesh = fleet.auto.get_mesh() + return "pp" in mesh.dim_names + + +def global_mesh_starts_with_pp(): + + mesh = fleet.auto.get_mesh() + if is_pp_enable(): + return mesh.get_mesh_with_dim("pp") + else: + return mesh + + +def is_fleety_func(): + """ + Check whether it is PaddlePaddle FleetY version. 
+ """ + if flash_attention_with_sparse_mask is None: + return True + + args = inspect.getfullargspec(flash_attention_with_sparse_mask).args + return "causal" in args + + +IS_FLEETY = is_fleety_func() + + +def get_triangle_upper_mask(x, mask=None): + + if mask is not None: + return mask + shape = x.shape + shape[1] = 1 + mask = paddle.full(shape, -np.inf, dtype=x.dtype) + mask.stop_gradient = True + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def naive_fuse_split_tp( + weight, + tensor_parallel_degree, + tensor_parallel_rank=None, + is_column=True, + fuse_tensor_parts=2, +): + + logging.info(f"spliting fused-ffn: {weight.shape}") + axis = -1 if is_column else 0 + splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis) + return np.concatenate( + splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis + ) + + +def parallel_matmul( + x, + y, + bias=None, + transpose_y=False, + tensor_parallel_degree=1, + tensor_parallel_output=True, +): + + if transpose_y: + logits = paddle.matmul(x, y, transpose_y=True) + if bias is not None: + logits += bias + else: + logits = F.linear(x, y, bias) + + if tensor_parallel_degree > 1 and not tensor_parallel_output: + logits = dist.reshard(logits, get_mesh(-1), [dist.Shard(0), dist.Replicate()]) + + return logits + + +def calc_lm_head_logits( + config, + hidden_states, + weight, + bias, + sparse_label_idx=None, + tensor_parallel_output=None, +): + """the core function to calc lm head""" + if config.sequence_parallel: + + assert ( + not config.use_sparse_head_and_loss_fn + ), "use_sparse_head_and_loss_fn is not supported now." + + hcg = paddle.distributed.fleet.get_hybrid_communicate_group() + dp_rank = hcg.get_data_parallel_rank() + sharding_rank = hcg.get_sharding_parallel_rank() + if dp_rank <= 1 and sharding_rank <= 1: + hidden_states = dist.reshard( + hidden_states, + get_mesh(-1), + [dist.Replicate(), dist.Replicate()], + ) + else: + hidden_states = dist.reshard( + hidden_states, + get_mesh(-1), + [dist.Shard(1), dist.Replicate()], + ) + # [S, B, H] to [B, S, H] + hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) + if not config.using_dynamic_sequence_length: + hidden_states = hidden_states.reshape( + [-1, config.seqlen, hidden_states.shape[-1]] + ) + else: + assert ( + config.micro_batch_size + ), "micro_batch_size should be set when using dygramic sequence length." 
+ hidden_states = hidden_states.reshape( + [config.micro_batch_size, -1, hidden_states.shape[-1]] + ) + if tensor_parallel_output is None: + tensor_parallel_output = config.tensor_parallel_output + logits = parallel_matmul( + hidden_states, + weight, + bias=bias, + transpose_y=config.tie_word_embeddings, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_output=tensor_parallel_output, + ) + + return logits + + +def finfo(dtype: paddle.dtype = None): + + if dtype is None: + dtype = paddle.get_default_dtype() + + if dtype == paddle.bfloat16: + + class BFloatFInfo: + min = -3.3895313892515355e38 + + return BFloatFInfo + if dtype == paddle.float32: + return np.finfo(np.float32) + if dtype == paddle.float16: + return np.finfo(np.float16) + if dtype == paddle.float64: + return np.finfo(np.float64) + + +def masked_fill(x, mask, value): + + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def mem_eff_attn( + query, key, value, pack_offset, drop_prob=0.0, dtype=paddle.bfloat16, training=True +): + + pack_offset = pack_offset.numpy() + shape = pack_offset.shape + assert len(shape) == 2, len(shape) + assert shape[0] == 1, shape[0] + n = pack_offset.size + pack_offset = pack_offset.flatten() + seqlens = [] + assert pack_offset[0] == 0, pack_offset[0] + for i in range(1, n): + if pack_offset[i] < 0: + break + cur = pack_offset[i] - pack_offset[i - 1] + assert cur > 0 + seqlens.append(cur) + + assert drop_prob == 0.0, drop_prob + assert dtype == paddle.bfloat16, dtype + + def cast(x): + return x.astype(dtype) if x.dtype != dtype else x + + if len(seqlens) == 1: + out, _ = flash_attention( + query, key, value, drop_prob, causal=True, training=training + ) + else: + mask = BlockDiagonalCausalMask.from_seqlens(seqlens) + out = memory_efficient_attention( + cast(query), + cast(key), + cast(value), + attn_bias=mask, + p=drop_prob, + training=training, + ) + return out + + +def inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset): + inbatch_pack_offset = inbatch_pack_offset.numpy() + attn_mask_row_start_indices = [] + min_start_row = np.inf + for bidx in range(inbatch_pack_offset.shape[0]): + item = inbatch_pack_offset[bidx] + cumsum_item = item[item != -1] + record_lens = cumsum_item[1:] - cumsum_item[0:-1] + min_start_row = min(cumsum_item[1], min_start_row) + row_start_indices = np.repeat(cumsum_item[1:], record_lens) + attn_mask_row_start_indices.append(row_start_indices[None, None, ...]) + attn_mask_row_start_indices = np.concatenate(attn_mask_row_start_indices, axis=0) + return paddle.to_tensor(attn_mask_row_start_indices, dtype=paddle.int32), int( + min_start_row + ) + + +def scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + output_attentions, + config, + is_causal=True, + rr_flash_attn=None, + inbatch_pack_offset=None, + training=True, +): + + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, num_key_value_heads, _ = value_states.shape + + can_use_fa = config.use_flash_attn and flash_attention is not None + can_use_fa_sparse_mask = ( + config.use_mem_eff_attn + and inbatch_pack_offset is not None + and flash_attention_with_sparse_mask is not None + ) + + if not can_use_fa and not can_use_fa_sparse_mask: + if query_states.shape[-2] != key_states.shape[-2]: + key_states = key_states.repeat_interleave( + num_heads // num_key_value_heads, axis=-2 + ) + if query_states.shape[-2] != value_states.shape[-2]: + value_states = value_states.repeat_interleave( + num_heads // 
num_key_value_heads, axis=-2 + ) + + if can_use_fa: + if rr_flash_attn is not None: + attn_output, attn_weights = rr_flash_attn( + query_states, + key_states, + value_states, + dropout=config.attention_probs_dropout_prob, + causal=is_causal and query_states.shape[1] != 1, + return_softmax=output_attentions, + ) + else: + attn_output, attn_weights = flash_attention( + query_states, + key_states, + value_states, + dropout=config.attention_probs_dropout_prob, + causal=is_causal and query_states.shape[1] != 1, + return_softmax=output_attentions, + ) + + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return attn_output, attn_weights + elif config.use_mem_eff_attn and inbatch_pack_offset is not None: + assert ( + not output_attentions + ), "output_attentions should be False when use_mem_eff_attn=True" + if config.use_flash_attn_with_mask: + if flash_attention_with_sparse_mask is not None: + causal_mask_indices, attn_mask_min_start_row = ( + inbatch_pack_offset_to_attn_mask_start_row_indices( + inbatch_pack_offset + ) + ) + if IS_FLEETY: + kwargs = { + "causal": True, + "dropout": config.attention_probs_dropout_prob, + } + else: + kwargs = { + "is_causal": True, + "dropout_p": config.attention_probs_dropout_prob, + } + attn_output = flash_attention_with_sparse_mask( + query_states.astype(value_states.dtype), + key_states.astype(value_states.dtype), + value_states.astype(value_states.dtype), + attn_mask_start_row_indices=causal_mask_indices, + attn_mask_start_row=attn_mask_min_start_row, + **kwargs, + ) + else: + attn_mask = to_block_diag_causal_mask( + inbatch_pack_offset, q_len, float("-inf"), "bfloat16" + ) + attn_output = flash_attention_with_mask( + query_states, + key_states, + value_states, + attn_mask, + config.attention_probs_dropout_prob, + ) + else: + attn_output = mem_eff_attn( + query_states, + key_states, + value_states, + inbatch_pack_offset, + drop_prob=config.attention_probs_dropout_prob, + ) + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return attn_output, None + else: + + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) / math.sqrt( + head_dim + ) + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + if training: + attn_weights = attention_mask + attn_weights + attn_weights = paddle.maximum( + attn_weights, + paddle.to_tensor( + float(finfo(query_states.dtype).min), dtype=query_states.dtype + ), + ) + + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + attn_weights = F.softmax( + attn_weights, axis=-1, dtype="float32" + ).astype(query_states.dtype) + else: + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype( + query_states.dtype + ) + else: + attn_weights = attn_weights.cast(paddle.float32) + attention_mask = attention_mask.cast(paddle.float32) + attn_weights = 
attn_weights.add_(attention_mask) + attn_weights = F.softmax_(attn_weights, axis=-1).astype(query_states.dtype) + + if config.attention_probs_dropout_prob > 0.0: + if config.tensor_parallel_degree > 1: + with get_rng_state_tracker().rng_state("local_seed"): + attn_weights = F.dropout( + attn_weights, + config.attention_probs_dropout_prob, + training=training, + mode="upscale_in_train", + ) + else: + attn_weights = F.dropout( + attn_weights, + config.attention_probs_dropout_prob, + training=training, + mode="upscale_in_train", + ) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + if output_attentions: + return attn_output, attn_weights + return attn_output, None + + +def _make_causal_mask(input_ids_shape, past_key_values_length, dtype): + batch_size, target_length = input_ids_shape + + mask = paddle.full((target_length, target_length), float(finfo(dtype).min)) + + mask_cond = paddle.arange(mask.shape[-1]) + mask = masked_fill( + mask, mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0 + ) + + if past_key_values_length > 0: + mask = paddle.concat( + [paddle.zeros([target_length, past_key_values_length]), mask], axis=-1 + ) + + return mask[None, None, :, :].expand( + [batch_size, 1, target_length, target_length + past_key_values_length] + ) + + +def _expand_mask(mask, dtype, tgt_length): + if mask.ndim == 4: + expanded_mask = mask + elif mask.ndim == 3: + expanded_mask = mask[:, None, :, :] + else: + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = mask[:, None, None, :].expand( + [batch_size, 1, tgt_length, src_length] + ) + + inverted_mask = 1.0 - expanded_mask + return masked_fill( + inverted_mask, inverted_mask.cast("bool"), float(finfo(dtype).min) + ) + + +def slice_experts(experts, moe_world_size): + moe_num_experts_per_device = len(experts) // moe_world_size + experts_per_device = [[] for _ in range(moe_world_size)] + + for i, expert in enumerate(experts): + ep_group_id = i // moe_num_experts_per_device + experts_per_device[ep_group_id].append(expert) + + lm_experts = nn.LayerList([]) + for experts_list in experts_per_device: + lm_experts.extend(experts_list[: moe_num_experts_per_device // 2]) + return lm_experts + + +def get_gate( + config: ErnieMoEConfig, + expert: Tuple[Tuple[int, nn.Layer]], + layer_idx: int, + ipp: int = 0, +) -> Tuple[nn.Layer, nn.LayerList]: + + moe_num_experts = config.moe_num_experts + assert ( + moe_num_experts >= config.moe_world_size + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={config.moe_world_size}" + assert ( + moe_num_experts % config.moe_world_size == 0 + ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={config.moe_world_size} == 0" + moe_num_experts_per_device = moe_num_experts // config.moe_world_size + experts = nn.LayerList([]) + for expert_id, (experts_num, fc) in enumerate(expert): + assert experts_num % config.moe_world_size == 0 + experts_to_append = [] + if not hasattr(fc, "__len__"): + experts_to_append.append(fc) + if expert_id == 1: + with paddle.utils.unique_name.guard("_mm_deepcopy"): + for _ in range(experts_num - 1): + experts_to_append.append(deepcopy(fc)) + else: + for _ in range(experts_num - 1): + experts_to_append.append(deepcopy(fc)) + else: + experts_to_append = fc + for ex in experts_to_append: + for p in ex.parameters(): + p.expert_type = 
f"expert_type_{expert_id}" + experts.extend(experts_to_append) + + logger.info( + f"using moe-world-size: {config.moe_world_size} " + f"expert-per-device: {moe_num_experts_per_device} " + ) + if config.moe_use_hard_gate and moe_num_experts <= 2: + gate = None + logger.info("MOE-GATE:-hard-gate") + else: + logger.info(f"MOE-GATE:-{config.moe_gate}") + gate = gate_class[config.moe_gate.lower()]( + config, layer_idx=layer_idx, group=config.moe_group, ipp=ipp + ) + + lm_gate, lm_experts = None, None + logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") + + index = 0 if config.moe_group == "dp" else 1 + ep_sub_meshes = dist.auto_parallel.api.split_mesh(get_mesh(ipp), index) + + for i, expert in enumerate(experts): + ep_group_id = i // moe_num_experts_per_device + if isinstance(expert, (ErnieMoeMLPFused, ErnieMoeMLP)): + experts[i].redistribute_expert( + ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()] + ) + experts[i].ep_group_id = ep_group_id + + return gate, experts, lm_gate, lm_experts + + +def _parse_moe_group(moe_group: str): + moe_group = moe_group.lower() + assert moe_group in { + "dp", + "mp", + "none", + }, f"moe-group not supported, got: {moe_group}" + logger.info(f"using moe-group: {moe_group}") + + return moe_group + + +class RMSNorm(nn.Layer): + + def __init__(self, config, ipp=0): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + def forward(self, hidden_states): + + if self.config.fuse_rms_norm: + return paddle.incubate.nn.functional.fused_rms_norm_ext( + hidden_states, self.weight, self.variance_epsilon + )[0] + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = ( + paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + ) + else: + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = ( + paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + ) + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + +class LayerNorm(nn.LayerNorm): + + def __init__(self, config, ipp=0): + super().__init__(config.hidden_size, epsilon=config.rms_norm_eps) + + self.use_fast_ln = config.use_fast_ln + if self.use_fast_ln: + assert fast_ln is not None + self.ipp = ipp + if config.pipeline_parallel_degree > 1: + self.weight = dist.shard_tensor( + self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()] + ) + self.bias = dist.shard_tensor( + self.bias, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()] + ) + + def forward(self, hidden_states): + if self.use_fast_ln: + return fast_ln(hidden_states, self.weight, self.bias, self._epsilon)[0] + else: + return super().forward(hidden_states) + + +class RotaryEmbedding(nn.Layer): + + def __init__(self, dim, max_position_embeddings=4096, base=10000): + + super().__init__() + self.base = base + self.max_position_embeddings = max_position_embeddings + inv_freq = 1.0 / ( + base ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / dim) + ) + t = paddle.arange(max_position_embeddings, dtype="float32") + freqs = paddle.einsum("i,j->ij", t, inv_freq.cast("float32")) + emb = paddle.concat([freqs, freqs], 
axis=-1) + + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() + + self._cast_to_low_precision = False + self._cast_to_low_precison = False + + def forward(self, x, seq_len=None): + + return ( + self.cos_cached[:seq_len, :], + self.sin_cached[:seq_len, :], + ) + + @classmethod + def rotate_half(cls, x): + + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) + + @classmethod + def apply_rotary_pos_emb(cls, q, k, cos, sin, offset: int = 0, position_ids=None): + if position_ids is not None: + assert offset == 0, offset + cos = F.embedding(position_ids, cos) + sin = F.embedding(position_ids, sin) + else: + cos = cos.unsqueeze(0) + sin = sin.unsqueeze(0) + cos = cos[:, offset : q.shape[1] + offset, None, :] + sin = sin[:, offset : q.shape[1] + offset, None, :] + + q_embed = paddle.add( + paddle.multiply(q, cos), paddle.multiply(cls.rotate_half(q), sin) + ) + k_embed = paddle.add( + paddle.multiply(k, cos), paddle.multiply(cls.rotate_half(k), sin) + ) + q_embed = q_embed.astype(q.dtype) + k_embed = k_embed.astype(k.dtype) + return q_embed, k_embed + + +class RopeEmbeddingLegacy(nn.Layer): + + def __init__(self, head_dim, compression_ratio=1.0, base=10000): + super().__init__() + self.head_dim = head_dim + self.compression_ratio = compression_ratio + self.base = base + + def forward(self, seq_length, position_ids=None): + + indices = paddle.arange(0, self.head_dim, 2, dtype="float32") + indices = 1 / self.base ** (indices / self.head_dim) + if position_ids is None: + position_ids = paddle.arange(0, seq_length, 1, dtype="float32").unsqueeze(1) + position_ids = position_ids / self.compression_ratio + sinusoid_inp = position_ids * indices.unsqueeze(0) + else: + position_ids = position_ids / self.compression_ratio + seq_length = position_ids.shape[-1] + sinusoid_inp = position_ids.unsqueeze(-1).astype( + "float32" + ) * indices.unsqueeze(0) + pos_emb = paddle.concat( + [paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1 + ) + pos_emb = paddle.reshape(pos_emb, (-1, 1, seq_length, self.head_dim)) + pos_emb.stop_gradient = True + return pos_emb + + def apply_rotary(self, rp, q, k): + + sin, cos = paddle.chunk(rp, 2, axis=-1) + sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), rp.shape) + cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), rp.shape) + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + query = paddle.add( + paddle.multiply(q.astype("float32"), cos_pos), + paddle.multiply(rotate_half_q.astype("float32"), sin_pos), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + key = paddle.add( + paddle.multiply(k.astype("float32"), cos_pos), + paddle.multiply(rotate_half_k.astype("float32"), sin_pos), + ) + return query, key + + def forward_single(self, position_ids): + + batch_size, seq_length = position_ids.shape[:2] + rope_emb = paddle.zeros( + (2, batch_size, seq_length, 1, self.head_dim), dtype="float32" + ) + inv_freq = self.base ** ( + -paddle.arange(0, self.head_dim, 2, dtype="float32") / self.head_dim + ) + position_ids = position_ids.cast("float32") + position_ids = position_ids / self.compression_ratio + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + emb = paddle.stack([freqs, freqs], axis=-1).reshape( + (batch_size, seq_length, self.head_dim) + ) + emb = paddle.unsqueeze(emb, 2) + + rope_emb[0] = paddle.cos(emb) + 
rope_emb[1] = paddle.sin(emb) + return rope_emb + + @staticmethod + def apply_rotary_single(x, rope_emb): + + rotate_half_x = paddle.reshape( + paddle.stack([-x[:, :, :, 1::2], x[:, :, :, 0::2]], axis=-1), + paddle.shape(x), + ) + return x * rope_emb[0] + rotate_half_x * rope_emb[1] + + +class ErnieLinear(nn.Layer): + + def __init__( + self, + in_features, + out_features, + weight_attr=None, + bias_attr=None, + name=None, + ipp=0, + ): + super(ErnieLinear, self).__init__() + self._dtype = self._helper.get_default_dtype() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self.weight = self.create_parameter( + shape=[in_features, out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False, + ) + self.bias = self.create_parameter( + shape=[out_features], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True, + ) + self.name = name + self.ipp = ipp + + def forward(self, input): + + out = F.linear(x=input, weight=self.weight, bias=None, name=self.name) + out = dist.reshard( + out, + get_mesh(self.ipp), + [dist.Shard(1), dist.Shard(0)], + ) + if self.bias: + out += self.bias + return out + + +class ErnieMLP(nn.Layer): + + def __init__(self, config, ipp=None, do_shard_tensor=True): + super().__init__() + self.config = config + self.ipp = ipp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + LinearFN = nn.Linear + self.gate_proj = LinearFN( + self.hidden_size, self.intermediate_size, bias_attr=config.use_bias + ) + self.up_proj = LinearFN( + self.hidden_size, self.intermediate_size, bias_attr=config.use_bias + ) + + if config.sequence_parallel: + self.down_proj = ErnieLinear( + self.intermediate_size, + self.hidden_size, + bias_attr=config.use_bias, + ipp=self.ipp, + ) + else: + self.down_proj = LinearFN( + self.intermediate_size, self.hidden_size, bias_attr=config.use_bias + ) + + if do_shard_tensor and ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + self.gate_proj.weight = dist.shard_tensor( + self.gate_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + self.up_proj.weight = dist.shard_tensor( + self.up_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + if config.use_bias: + self.gate_proj.bias = dist.shard_tensor( + self.gate_proj.bias, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + self.up_proj.bias = dist.shard_tensor( + self.up_proj.bias, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + self.down_proj.weight = dist.shard_tensor( + self.down_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + if config.use_bias: + self.down_proj.bias = dist.shard_tensor( + self.down_proj.bias, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) + + self.fuse_swiglu = config.fuse_swiglu + if self.fuse_swiglu: + assert fused_swiglu is not None, "fused_swiglu operator is not found." 
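+ # forward() below implements a SwiGLU feed-forward block, roughly
+ #     down_proj(silu(gate_proj(x)) * up_proj(x))
+ # using the fused_swiglu kernel when config.fuse_swiglu is enabled and the
+ # plain F.silu formulation otherwise; both paths compute the same function.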
+ + def forward(self, x): + + if self.fuse_swiglu: + x = fused_swiglu(self.gate_proj(x), self.up_proj(x)) + else: + x = F.silu(self.gate_proj(x)) * self.up_proj(x) + return self.down_proj(x) + + +class ErnieAttentionAuto(nn.Layer): + + def __init__(self, config, ipp: Optional[int] = None): + super().__init__() + self.ipp = ipp + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_heads + self.use_recompute_attn = config.use_recompute_attn + self.is_gqa = ( + config.num_key_value_heads is not None + and config.num_key_value_heads != self.num_heads + ) + if config.fuse_rope: + assert fused_rope is not None, "fused_rope is not supported" + self.fuse_rope = config.fuse_rope + + if self.is_gqa: + logger.info( + f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}" + ) + assert ( + self.num_heads % self.num_key_value_heads == 0 + ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}" + kv_hidden_size = ( + self.hidden_size // self.num_heads * self.num_key_value_heads + ) + + LinearFN = nn.Linear + self.q_proj = LinearFN( + self.hidden_size, + self.hidden_size, + bias_attr=config.use_bias, + ) + self.k_proj = LinearFN( + self.hidden_size, + self.hidden_size if not self.is_gqa else kv_hidden_size, + bias_attr=config.use_bias, + ) + self.v_proj = LinearFN( + self.hidden_size, + self.hidden_size if not self.is_gqa else kv_hidden_size, + bias_attr=config.use_bias, + ) + + if config.sequence_parallel: + self.o_proj = ErnieLinear( + self.hidden_size, + self.hidden_size, + bias_attr=config.use_bias, + ipp=self.ipp, + ) + else: + self.o_proj = LinearFN( + self.hidden_size, + self.hidden_size, + bias_attr=config.use_bias, + ) + + self.config = config + + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + self.q_proj.weight = dist.shard_tensor( + self.q_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + self.k_proj.weight = dist.shard_tensor( + self.k_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + self.v_proj.weight = dist.shard_tensor( + self.v_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(1)], + ) + if config.use_bias: + self.q_proj.bias = dist.shard_tensor( + self.q_proj.bias, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + self.k_proj.bias = dist.shard_tensor( + self.k_proj.bias, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + self.v_proj.bias = dist.shard_tensor( + self.v_proj.bias, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + self.o_proj.weight = dist.shard_tensor( + self.o_proj.weight, + get_mesh(self.ipp), + [dist.Replicate(), dist.Shard(0)], + ) + + def forward( + self, + hidden_states, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + inbatch_pack_offset: Optional[Tuple[paddle.Tensor]] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + if self.config.sequence_parallel: + hidden_states = dist.reshard( + hidden_states, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()] + ) + + query_states = self.q_proj(hidden_states).reshape( + shape=[0, 0, self.num_heads, self.head_dim] + ) + key_states = 
self.k_proj(hidden_states).reshape( + shape=[ + 0, + 0, + self.num_key_value_heads if self.is_gqa else self.num_heads, + self.head_dim, + ] + ) + value_states = self.v_proj(hidden_states).reshape( + shape=[ + 0, + 0, + self.num_key_value_heads if self.is_gqa else self.num_heads, + self.head_dim, + ] + ) + + if self.config.sequence_parallel: + query_states = paddle.transpose(query_states, [1, 0, 2, 3]) + key_states = paddle.transpose(key_states, [1, 0, 2, 3]) + value_states = paddle.transpose(value_states, [1, 0, 2, 3]) + + if self.use_recompute_attn: + assert past_key_value is None, "do not use kv cache in recompute" + assert not use_cache + attn_output, attn_weights, past_key_value = recompute( + self.rope_attn, + None, + query_states, + key_states, + value_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + use_reentrant=False, + ) + else: + attn_output, attn_weights, past_key_value = self.rope_attn( + mix_layer=None, + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + inbatch_pack_offset=inbatch_pack_offset, + ) + + if self.config.sequence_parallel: + attn_output = paddle.transpose(attn_output, [1, 0, 2]) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def rope_attn( + self, + mix_layer, + query_states, + key_states, + value_states, + attention_mask, + position_ids, + output_attentions=False, + past_key_value=None, + use_cache=False, + inbatch_pack_offset=None, + ): + if mix_layer is not None: + query_states, key_states, value_states = paddle.split(mix_layer, 3, axis=-1) + query_states_dtype = query_states.dtype + + kv_seq_len = key_states.shape[-3] + offset = 0 + if past_key_value is not None: + offset = past_key_value[0].shape[-3] + kv_seq_len += offset + + if self.config.rope_reorder: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = self.rotary_emb.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + position_ids=position_ids, + offset=offset if position_ids is None else 0, + ) + else: + if offset > 0 or position_ids is not None or not self.fuse_rope: + cos_sin = self.rotary_emb(kv_seq_len, position_ids).transpose( + [0, 2, 1, 3] + ) + if offset > 0 and position_ids is None: + cos_sin = cos_sin[:, offset:] + query_states, key_states = self.rotary_emb.apply_rotary( + cos_sin, query_states, key_states + ) + else: + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, num_key_value_heads, _ = key_states.shape + if num_heads != num_key_value_heads: + query_states, _, _ = fused_rope(query_states, None, None) + key_states, _, _ = fused_rope(key_states, None, None) + else: + query_states, key_states, _ = fused_rope( + query_states, key_states, None + ) + + if use_cache: + query_states = query_states.astype(query_states_dtype) + key_states = key_states.astype(query_states_dtype) + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = [key_states, value_states] if use_cache else None + + attn_output, attn_weights = scaled_dot_product_attention( + query_states=query_states, + key_states=key_states, + 
value_states=value_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + config=self.config, + inbatch_pack_offset=inbatch_pack_offset, + training=self.training, + ) + return attn_output, attn_weights, past_key_value + + +class ErnieMoeMLP(ErnieMLP): + """_summary_ + + Args: + ErnieMoeMLP (_type_): _description_ + """ + + def __init__(self, config, ipp=0): + """ + doc + """ + disable_ffn_model_parallel = getattr( + config, "disable_ffn_model_parallel", False + ) + if disable_ffn_model_parallel: + config = deepcopy(config) + config.tensor_parallel_degree = 1 + config.sequence_parallel = False + + super().__init__(config, ipp, do_shard_tensor=not disable_ffn_model_parallel) + self.moe_dropout_prob = config.moe_dropout_prob + self.fuse_swiglu = config.fuse_swiglu + if self.fuse_swiglu: + assert fused_swiglu is not None, "fused_swiglu operator is not found." + + def redistribute_expert(self, mesh, placements): + """ + Place the experts on different devices. + """ + self.gate_proj.weight = dist.shard_tensor( + self.gate_proj.weight, mesh, placements + ) + self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements) + self.down_proj.weight = dist.shard_tensor( + self.down_proj.weight, mesh, placements + ) + if self.config.use_bias: + self.gate_proj.bias = dist.shard_tensor( + self.gate_proj.bias, mesh, placements + ) + self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements) + self.down_proj.bias = dist.shard_tensor( + self.down_proj.bias, mesh, placements + ) + + def forward(self, x): + + if self.fuse_swiglu: + x = fused_swiglu(self.gate_proj(x), self.up_proj(x)) + else: + x = F.silu(self.gate_proj(x)) * self.up_proj(x) + if self.moe_dropout_prob > 0: + with get_rng_state_tracker().rng_state("local_seed"): + x = F.dropout(x=x, p=self.moe_dropout_prob) + ret = self.down_proj(x) + return ret + + +class BMMLinear(nn.Layer): + + def __init__(self, experts, d_in, d_out, use_bias=False): + super().__init__() + self.weight = self.create_parameter( + [experts, d_in, d_out], dtype=paddle.get_default_dtype() + ) + if use_bias: + self.bias = self.create_parameter( + [experts, d_out], dtype=paddle.get_default_dtype(), is_bias=True + ) + else: + self.bias = None + + def forward(self, x): + """x: [num_experts, Seq, dim]""" + if self.bias is not None: + return paddle.bmm(x, self.weight) + self.bias + return paddle.bmm(x, self.weight) + + +class ErnieMoeMLPFused(nn.Layer): + + def __init__(self, config): + + assert ( + hasattr(config, "disable_ffn_model_parallel") + or config.tensor_parallel_degree == 1 + ), f"fused mlp only suport mp-moe, mp={config.tensor_parallel_degree}" + assert config.fuse_attn_ffn, "fused mlp only support fuse_attn_ffn" + super().__init__() + self.moe_dropout_prob = config.moe_dropout_prob + self.num_local_experts = config.moe_num_experts // config.moe_world_size + logger.info( + f"fused-expert-weight-shape: {[self.num_local_experts, config.hidden_size, config.intermediate_size]}" + ) + + self.up_gate_proj = BMMLinear( + self.num_local_experts, config.hidden_size, config.intermediate_size * 2 + ) + self.down_proj = BMMLinear( + self.num_local_experts, config.intermediate_size, config.hidden_size + ) + self.fuse_swiglu = config.fuse_swiglu + if self.fuse_swiglu: + assert fused_swiglu is not None, "fused_swiglu operator is not found." 
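+ # Illustrative shape sketch of the fused expert path (see forward below):
+ #     x:            [num_local_experts, tokens_per_expert, hidden_size]
+ #     up_gate_proj: batched matmul -> [..., intermediate_size * 2]
+ #     swiglu:       silu(gate) * up  (fused kernel when config.fuse_swiglu)
+ #     down_proj:    batched matmul -> [..., hidden_size]
+ # so all local experts are applied with one bmm per projection instead of a
+ # Python loop over per-expert Linear layers.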
+ + def __len__(self): + return self.num_local_experts + + def __iter__(self): + return (self for _ in range(1)) + + def forward(self, x): + if self.fuse_swiglu: + x = fused_swiglu(self.up_gate_proj(x)) + else: + gate, x = self.up_gate_proj(x).chunk(2, axis=-1) + x = F.silu(gate) * x + x = self.down_proj(x) + return x + + +class ErnieDecoderLayerAuto(nn.Layer): + """ + ErnieDecoderLayerAuto is a decoder layer in Ernie model. + It is composed of self-attention, cross-attention and feedforward layers. + """ + + def __init__(self, config, layer_idx=0, ipp=0): + """ + Initializes the ErnieBlock module. + + Args: + config (ErnieConfig): The model configuration. + layer_idx (int, optional): The index of this block in the model. Defaults to 0. + ipp (int, optional): The index of this block in the pipeline parallelism. Defaults to 0. + """ + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.ipp = ipp + self.hidden_size = config.hidden_size + self.self_attn = ErnieAttentionAuto(config, ipp) + self.use_moe = config.use_moe if hasattr(config, "use_moe") else False + if self.use_moe: + moe_layer_start_index = ( + min(config.moe_layer_start_index) + if isinstance(config.moe_layer_start_index, (tuple, list)) + else config.moe_layer_start_index + ) + moe_layer_end_index = ( + max(config.moe_layer_end_index) + if isinstance(config.moe_layer_end_index, (tuple, list)) + else config.moe_layer_end_index + ) + + if ( + self.use_moe + and ((layer_idx + 1) % config.moe_layer_interval == 0) + and layer_idx >= moe_layer_start_index + and layer_idx <= moe_layer_end_index + ): + self.create_moe_mlp_layer(layer_idx, ipp) + else: + self.mlp = ErnieMLP(config, ipp) + Norm = RMSNorm if config.use_rmsnorm else LayerNorm + self.input_layernorm = Norm(config, ipp) + self.post_attention_layernorm = Norm(config, ipp) + self.residual_add1 = FusedDropoutImpl( + config.hidden_dropout_prob, mode="upscale_in_train" + ) + self.residual_add2 = FusedDropoutImpl( + config.hidden_dropout_prob, mode="upscale_in_train" + ) + + def create_moe_mlp_layer(self, layer_idx, ipp): + _ex_cfg = deepcopy(self.config) + fc_cls = ErnieMoeMLPFused if _ex_cfg.moe_fuse_experts else ErnieMoeMLP + if _ex_cfg.moe_intermediate_size: + if isinstance(_ex_cfg.moe_intermediate_size, (tuple, list)): + assert isinstance(_ex_cfg.moe_num_experts, (tuple, list)) and len( + _ex_cfg.moe_num_experts + ) == len(_ex_cfg.moe_intermediate_size) + fc = [] + for _i, (num_experts, intermediate_size) in enumerate( + zip(_ex_cfg.moe_num_experts, _ex_cfg.moe_intermediate_size) + ): + _ex_cfg_real = deepcopy(_ex_cfg) + _ex_cfg_real.intermediate_size = intermediate_size + cur_modality_start_layer_idx = ( + self.config.moe_layer_start_index[_i] + if isinstance(self.config.moe_layer_start_index, (tuple, list)) + else self.config.moe_layer_start_index + ) + cur_modality_end_layer_idx = ( + self.config.moe_layer_end_index[_i] + if isinstance(self.config.moe_layer_end_index, (tuple, list)) + else self.config.moe_layer_end_index + ) + if ( + layer_idx >= cur_modality_start_layer_idx + and layer_idx <= cur_modality_end_layer_idx + ): + if _i == 1: + with paddle.utils.unique_name.guard( + f"mm_expert_{layer_idx}_" + ): + fc.append((num_experts, fc_cls(_ex_cfg_real))) + else: + fc.append((num_experts, fc_cls(_ex_cfg_real))) + else: + logger.info( + f"moe multimodal experts use Identity layer_idx: {layer_idx}" + ) + fc.append((num_experts, nn.Identity())) + else: + _ex_cfg.intermediate_size = _ex_cfg.moe_intermediate_size + fc = [(_ex_cfg.moe_num_experts, 
fc_cls(_ex_cfg))] + else: + fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))] + gate, experts, lm_gate, lm_experts = get_gate( + self.config, fc, layer_idx, self.ipp + ) + _sh_cfg = deepcopy(self.config) + + if _sh_cfg.moe_num_shared_experts > 0: + if _sh_cfg.moe_intermediate_size: + _sh_inter_size = ( + _sh_cfg.moe_intermediate_size[0] + if isinstance(_sh_cfg.moe_intermediate_size, (tuple, list)) + else _sh_cfg.moe_intermediate_size + ) + _sh_cfg.intermediate_size = ( + _sh_inter_size * _sh_cfg.moe_num_shared_experts + ) + else: + _sh_cfg.intermediate_size = ( + _sh_cfg.intermediate_size * _sh_cfg.moe_num_shared_experts + ) + _sh_cfg.disable_ffn_model_parallel = False + shared_experts = ErnieMoeMLP(_sh_cfg, ipp) + else: + shared_experts = None + + is_moe_infer = self.config.get("is_moe_infer", False) + if is_moe_infer: + raise NotImplementedError + elif self.config.moe_use_size_all2all: + raise NotImplementedError + else: + logger.info(f"moe-logging:{self.config.moe_logging}") + moe_cls = MOELayerAuto + self.mlp = moe_cls( + gate, + experts, + layer_idx=layer_idx, + shared_experts=shared_experts, + group=self.config.moe_group, + recompute=self.config.use_recompute_moe, + k=self.config.moe_k, + enable_pbr=self.config.moe_use_bpr, + all_to_all_dropout=self.config.moe_all_to_all_dropout, + group_experts=self.config.moe_group_experts, + config=self.config, + ipp=self.ipp, + ) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + inbatch_pack_offset: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + output_gate_logits=True, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `cache` key value states are returned and can be used to speed up decoding + (see `cache`). 
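+ token_type_ids (`paddle.Tensor`, *optional*): per-token ids forwarded to the MoE layer when this block uses `MOELayerAuto`.
+ output_gate_logits (`bool`, *optional*): if `True`, the MoE gate logits of this layer (or `None` for dense layers) are appended to the returned tuple. Defaults to `True`.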
+ cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + (hidden_states, self_attn_weights, present_key_value, *router_loss_attn) = ( + self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + use_cache=use_cache, + inbatch_pack_offset=inbatch_pack_offset, + ) + ) + + if ( + self.config.tensor_parallel_degree > 1 + and self.config.hidden_dropout_prob > 0.0 + ): + current_seed = ( + "local_seed" if self.config.sequence_parallel else "global_seed" + ) + with get_rng_state_tracker().rng_state(current_seed): + hidden_states = self.residual_add1(hidden_states, residual) + else: + hidden_states = self.residual_add1(hidden_states, residual) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if isinstance( + self.mlp, + (MOELayerAuto), + ): + + hidden_states, _, router_loss, gate_logits = self.mlp( + hidden_states, token_type_ids + ) + else: + if self.config.sequence_parallel: + hidden_states = dist.reshard( + hidden_states, + get_mesh(self.ipp), + [dist.Shard(1), dist.Replicate()], + ) + hidden_states = self.mlp(hidden_states) + gate_logits = None + + if ( + self.config.tensor_parallel_degree > 1 + and self.config.hidden_dropout_prob > 0.0 + ): + current_seed = ( + "local_seed" if self.config.sequence_parallel else "global_seed" + ) + with get_rng_state_tracker().rng_state(current_seed): + hidden_states = self.residual_add2(hidden_states, residual) + else: + hidden_states = self.residual_add2(hidden_states, residual) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if hasattr(self.config, "use_moe") and self.config.use_moe: + if router_loss_attn: + router_loss_attn = router_loss_attn[0] + router_loss = router_loss + router_loss_attn + + if isinstance(self.mlp, (MOELayerAuto)): + outputs += (router_loss,) + else: + outputs += (paddle.zeros([1], dtype=paddle.float32),) + + if output_gate_logits: + outputs += (gate_logits,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + return outputs + + +class ErniePretrainedModelAuto(PretrainedModel): + """ + ErniePretrainedModelAuto is a pretrained model class for Ernie model. + It is composed of a encoder and a decoder. 
+ """ + + config_class = ErnieMoEConfig + base_model_prefix = "ernie" + + @classmethod + def _get_name_mappings(cls, config: ErnieMoEConfig) -> StateDictNameMapping: + + mappings: StateDictNameMapping = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range( + config.num_hidden_layers + if not config.remove_tail_layer + else config.num_hidden_layers - 1 + ): + layer_mappings = [ + [ + f"layers.{layer_index}.self_attn.q_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.k_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.v_proj.weight", + None, + "transpose", + ], + [ + f"layers.{layer_index}.self_attn.o_proj.weight", + None, + "transpose", + ], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + if "ErnieModelAuto" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "ernie." + mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [ + StateDictNameMapping(*mapping, index=index) + for index, mapping in enumerate(model_mappings) + ] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + base_actions = { + "layers.0.self_attn.q_proj.weight": partial(fn, is_column=True), + "layers.0.self_attn.k_proj.weight": partial(fn, is_column=True), + "layers.0.self_attn.v_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.up_proj.weight": partial(fn, is_column=True), + "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings), + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + if config.use_bias: + base_actions.update( + { + "layers.0.self_attn.q_proj.bias": partial(fn, is_column=True), + "layers.0.self_attn.k_proj.bias": partial(fn, is_column=True), + "layers.0.self_attn.v_proj.bias": partial(fn, is_column=True), + "layers.0.mlp.gate_proj.bias": partial(fn, is_column=True), + "layers.0.mlp.up_proj.bias": partial(fn, is_column=True), + "lm_head.bias": partial(fn, is_column=True), + } + ) + for key, action in base_actions.items(): + if "layers.0." 
in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings( + config.num_hidden_layers + if not config.remove_tail_layer + else config.num_hidden_layers - 1 + ) + + return mappings + + def init_weights(self, layer): + """Initialization hook""" + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + else: + rng_tracker = contextlib.nullcontext + + if isinstance( + layer, + ( + ErnieLMHead, + nn.Embedding, + nn.Linear, + paddle.incubate.nn.FusedLinear, + ), + ): + + with rng_tracker(): + dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float32") + if layer.weight._is_initialized(): + if layer.weight.is_dist(): + layer.weight._local_value().set_value( + paddle.randn( + layer.weight._local_shape, dtype=layer.weight.dtype + ).scale(self.config.initializer_range) + ) + else: + layer.weight.set_value( + paddle.randn( + layer.weight.shape, dtype=layer.weight.dtype + ).scale(self.config.initializer_range) + ) + paddle.set_default_dtype(dtype) + logger.info( + f"dist-init-fc: shape={layer.weight.shape}, " + f" range={self.config.initializer_range}," + f' type={type(layer)},norm={layer.weight.astype("float32").norm()}' + ) + + elif isinstance(layer, RotaryEmbedding): + head_dim = self.config.hidden_size // self.config.num_attention_heads + inv_freq = 1.0 / ( + layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) + ) + + t = np.arange(layer.max_position_embeddings, dtype="float32") + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate([freqs, freqs], axis=-1) + cos_cached = np.cos(emb)[:, :] + sin_cached = np.sin(emb)[:, :] + layer.cos_cached.set_value(cos_cached) + layer.sin_cached.set_value(sin_cached) + elif isinstance(layer, Top2Gate): + if not hasattr(layer, "weight"): + return + with rng_tracker("model_parallel_rng"): + dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float32") + if self.config.moe_group_experts: + if layer.weight._is_initialized(): + layer.weight.set_value( + paddle.randn( + layer.weight.shape, dtype=layer.weight.dtype + ).scale(self.config.initializer_range) + ) + else: + if layer.weight._is_initialized(): + granularity = ( + 1 + if self.config.moe_intermediate_size == 0 + else self.config.intermediate_size + // self.config.moe_intermediate_size + ) + layer.weight.set_value( + paddle.randn( + [ + self.config.hidden_size, + self.config.moe_num_experts // granularity, + ], + dtype="float32", + ) + .scale(self.config.initializer_range) + .repeat_interleave(granularity, axis=-1) + ) + logger.info( + f"dist-init-moe_gate: shape={layer.weight.shape}, dtype={layer.weight.dtype} " + f"range={self.config.initializer_range},type={type(layer)}, " + f'norm={layer.weight.astype("float32").norm()}' + ) + + +@register_base_model +class ErnieModelAuto(ErniePretrainedModelAuto): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
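+ (or *config.num_hidden_layers* - 1 layers when `config.remove_tail_layer` is set).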
Each layer is a [`ErnieDecoderLayerAuto`] + Args: + config: ErnieMoEConfig + """ + + def __init__(self, config: ErnieMoEConfig): + if hasattr(config, "use_moe") and config.use_moe: + if config.moe_group in {"mp", "model", "tp", "mpdp"}: + assert config.sequence_parallel + logger.info( + f"disable FFN tensor model parallel, moe-group={config.moe_group}" + ) + config.disable_ffn_model_parallel = True + + config.moe_group = _parse_moe_group(config.moe_group) + if config.moe_group in fleet.auto.get_mesh().dim_names: + config.moe_world_size = fleet.auto.get_mesh().get_dim_size( + config.moe_group + ) + if config.moe_world_size < 0: + config.moe_world_size = 1 + else: + config.moe_world_size = 1 + + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.config = config + + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + if not in_auto_parallel_align_mode(): + self.embed_tokens.weight = dist.shard_tensor( + self.embed_tokens.weight, + get_mesh(), + [dist.Replicate(), dist.Shard(1)], + ) + + layers_list = [] + + def get_layer_pp_info(ipp): + mesh = fleet.auto.get_mesh() + if is_pp_enable() is False: + return None, False + else: + pp_degree = mesh.get_dim_size("pp") + layer_num = ( + config.num_hidden_layers - 1 + if config.remove_tail_layer + else config.num_hidden_layers + ) + layer_per_stage = math.ceil(layer_num / pp_degree) + input_need_reshard = ipp % layer_per_stage == 0 + return ipp // layer_per_stage, input_need_reshard + + self.next_pp_stage_indexes = [] + for layer_idx in range( + config.num_hidden_layers - 1 + if config.remove_tail_layer + else config.num_hidden_layers + ): + pp_stage_id, input_need_reshard = get_layer_pp_info(layer_idx) + layers_list.append(ErnieDecoderLayerAuto(config, layer_idx, pp_stage_id)) + if input_need_reshard: + self.next_pp_stage_indexes.append(layer_idx) + self.layers = nn.LayerList(layers_list) + Norm = RMSNorm if config.use_rmsnorm else LayerNorm + + self.norm = Norm(config, -1) + + self.gradient_checkpointing = False + + self.placements = ( + [dist.Shard(1), dist.Shard(0)] + if self.config.sequence_parallel + else [dist.Shard(0), dist.Replicate()] + ) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @classmethod + def _prepare_decoder_attention_mask( + cls, attention_mask, input_shape, past_key_values_length, dtype + ): + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length, dtype=dtype + ) + + if attention_mask is not None: + expanded_attn_mask = _expand_mask( + attention_mask, dtype, tgt_length=input_shape[-1] + ) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + combined_attention_mask = paddle.maximum( + combined_attention_mask.astype(dtype), + paddle.to_tensor(float(finfo(dtype).min), dtype=dtype), + ) + return combined_attention_mask + + def recompute_training( + self, + layer_module, + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids, + ): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, 
output_gate_logits=False) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids, + use_reentrant=False, + ) + return hidden_states + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + inbatch_pack_offset=None, + token_type_ids=None, + **kwargs, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + + seq_length_with_past = seq_length + cache_length = 0 + + if past_key_values[0] is not None: + cache_length = paddle.shape(past_key_values[0][0])[1] + seq_length_with_past += cache_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids).astype( + self.embed_tokens.weight.dtype + ) + + global_mesh = global_mesh_starts_with_pp() + if self.config.sequence_parallel: + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) + + if position_ids is not None: + position_ids = dist.shard_tensor( + position_ids, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + can_use_fa = self.config.use_flash_attn and flash_attention is not None + can_mem_eff_attn = ( + self.config.use_mem_eff_attn and inbatch_pack_offset is not None + ) + if can_use_fa or can_mem_eff_attn: + if attention_mask is not None: + attention_mask = None + + elif attention_mask is None: + attention_mask = paddle.ones( + (batch_size, seq_length_with_past), dtype=paddle.bool + ) + + if attention_mask is not None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + cache_length, + inputs_embeds.dtype, + ) + attention_mask = dist.shard_tensor( + attention_mask, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + + hidden_states = inputs_embeds + if self.config.tensor_parallel_degree > 1: + hidden_states = dist.reshard(hidden_states, get_mesh(0), self.placements) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + all_router_loss = None + if hasattr(self.config, "use_moe") and self.config.use_moe: + all_router_loss = paddle.to_tensor(0.0) + all_router_loss = dist.shard_tensor( + all_router_loss, get_mesh(0), dist.Replicate() + ) + all_gate_logits = () if hasattr(self.config, "use_moe") else None + for idx, (decoder_layer) in enumerate(self.layers): + if 
output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + has_gradient = not hidden_states.stop_gradient + ipp = decoder_layer.ipp + if not is_pp_enable(): + position_ids_input = position_ids + attention_mask_input = attention_mask + token_type_ids_input = token_type_ids + else: + if position_ids is not None: + position_ids_input = dist.reshard( + position_ids, + get_mesh(ipp), + [dist.Replicate(), dist.Replicate()], + ) + else: + position_ids_input = position_ids + attention_mask_input = ( + dist.reshard( + attention_mask, + get_mesh(ipp), + [dist.Replicate(), dist.Replicate()], + ) + if attention_mask is not None + else None + ) + token_type_ids_input = ( + dist.reshard( + token_type_ids, + get_mesh(ipp), + [dist.Replicate(), dist.Replicate()], + ) + if token_type_ids is not None + else None + ) + + if idx in self.next_pp_stage_indexes: + hidden_states = dist.reshard( + hidden_states, + get_mesh(ipp), + self.placements, + ) + if hasattr(self.config, "use_moe") and self.config.use_moe: + all_router_loss = dist.reshard( + all_router_loss, + get_mesh(ipp), + [dist.Replicate()], + ) + if self.config.use_recompute and has_gradient: + layer_outputs = self.recompute_training( + decoder_layer, + hidden_states, + attention_mask_input, + position_ids_input, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids_input, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask_input, + position_ids_input, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids_input, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + if hasattr(self.config, "use_moe") and self.config.use_moe: + if not (self.config.use_recompute and has_gradient): + layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1] + all_gate_logits = all_gate_logits + (gate_logits,) + router_loss = layer_outputs[-1] + all_router_loss += router_loss + + if use_cache and not (hasattr(self.config, "use_moe") and self.config.use_moe): + hidden_states = paddle.unsqueeze(hidden_states[:, -1, :], 1) + + if self.config.pipeline_parallel_degree > 1: + hidden_states = dist.reshard( + hidden_states, + get_mesh(-1), + self.placements, + ) + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_router_loss, + all_gate_logits, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + router_loss=all_router_loss, + gate_logits=all_gate_logits, + ) + + +class ErniePretrainingCriterionBase(paddle.nn.Layer): + """ + Criterion for Ernie. + It calculates the final loss. 
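+ When `config.use_sparse_head_and_loss_fn` is enabled, logits are computed only for the non-ignored label positions, and long inputs are split into sub-batches of `loss_subbatch_seqlen` tokens before the loss is evaluated.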
+ """ + + def __init__(self, config, return_tuple=True): + super(ErniePretrainingCriterionBase, self).__init__() + self.ignored_index = getattr(config, "ignored_index", -100) + self.config = config + self.return_tuple = return_tuple + self.enable_parallel_cross_entropy = ( + config.tensor_parallel_degree > 1 and config.tensor_parallel_output + ) + + self.loss_func = paddle.nn.CrossEntropyLoss( + reduction="none", + ) + + def forward(self, prediction_scores, masked_lm_labels): + if self.config.use_sparse_head_and_loss_fn: + hidden_states, outlinear_weight, outlinear_bias = prediction_scores + + if self.config.sequence_parallel: + masked_lm_labels, sparse_label_idx = ( + sequence_parallel_sparse_mask_labels( + masked_lm_labels, self.ignored_index + ) + ) + else: + masked_lm_labels = masked_lm_labels.flatten() + sparse_label_idx = paddle.nonzero( + masked_lm_labels != self.ignored_index + ).flatten() + masked_lm_labels = paddle.take_along_axis( + masked_lm_labels, sparse_label_idx, axis=0 + ) + + hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]]) + hidden_states = paddle.take_along_axis( + hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0 + ) + + if self.config.use_recompute_loss_fn: + res = recompute( + self.forward_impl_with_calc_logits, + masked_lm_labels, + hidden_states, + outlinear_weight, + outlinear_bias, + sparse_label_idx, + ) + else: + logits = calc_lm_head_logits( + self.config, + hidden_states, + outlinear_weight, + outlinear_bias, + sparse_label_idx, + ) + res = self.forward_impl(logits, masked_lm_labels) + elif self.config.use_recompute_loss_fn: + assert isinstance(prediction_scores, tuple) and len(prediction_scores) in [ + 3, + 4, + ] + res = recompute( + self.forward_impl_with_calc_logits, masked_lm_labels, *prediction_scores + ) + else: + res = self.forward_impl(prediction_scores, masked_lm_labels) + + return res + + def forward_impl_with_calc_logits( + self, + masked_lm_labels, + hidden_states, + outlinear_weight, + outlinear_bias, + sparse_label_idx=None, + tensor_parallel_output=None, + ): + + logits = calc_lm_head_logits( + self.config, + hidden_states, + outlinear_weight, + outlinear_bias, + sparse_label_idx, + tensor_parallel_output, + ) + + return self.forward_impl(logits, masked_lm_labels) + + def loss_impl(self, prediction_scores, masked_lm_labels): + """extract loss impl for subbatch""" + masked_lm_loss = self.loss_func( + prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(-1) + ) + return masked_lm_loss + + def forward_impl(self, prediction_scores, masked_lm_labels): + + with paddle.amp.auto_cast(False): + if self.config.use_sparse_head_and_loss_fn and prediction_scores.shape[ + 0 + ] > self.config.get("loss_subbatch_seqlen", 32768): + sb_loss_func = subbatch( + self.loss_impl, + [0, 1], + [0, 0], + self.config.get("loss_subbatch_seqlen", 32768), + 0, + ) + masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels) + else: + masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels) + lossmask = masked_lm_labels != self.ignored_index + + if (~lossmask).all(): + logger.warning( + f"encounter empty span when calculate loss, ignored_index={self.ignored_index}" + ) + loss = paddle.mean(masked_lm_loss) * 0.0 + loss_sum = masked_lm_loss.sum().detach() + else: + lossmask_ = lossmask.reshape([-1]).cast(paddle.float32) + masked_lm_loss_ = paddle.sum( + masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask_ + ) + loss = masked_lm_loss_ / lossmask_.sum() + loss_sum = masked_lm_loss_.sum().detach() + + if not 
self.return_tuple: + if self.training: + return loss + return loss_sum + return loss, loss_sum + + +class ErniePretrainingCriterion(ErniePretrainingCriterionBase): + """ + Criterion for Ernie. + It calculates the final loss. + """ + + def __init__(self, config, return_tuple=True): + super(ErniePretrainingCriterion, self).__init__( + config, return_tuple=return_tuple + ) + + def forward(self, prediction_scores, masked_lm_labels, router_loss=None): + """ + calculates the final loss + """ + res = super().forward( + prediction_scores, + masked_lm_labels, + ) + if self.return_tuple: + loss, loss_sum = res + else: + loss, loss_sum = res, None + if router_loss is not None and not in_auto_parallel_align_mode(): + global_mesh = global_mesh_starts_with_pp() + if self.config.pipeline_parallel_degree > 1: + loss = dist.reshard( + loss, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + router_loss = dist.reshard( + router_loss, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + loss = loss + router_loss - router_loss.detach() + return loss, loss_sum + + +class ErnieLMHead(nn.Layer): + """ + ErnieLMHead is the linear layer used to project hidden state of decoder into word embeddings. + """ + + def __init__(self, config): + super(ErnieLMHead, self).__init__() + self.config = config + vocab_size = config.vocab_size + self.weight = self.create_parameter( + shape=( + [vocab_size, config.hidden_size] + if config.tie_word_embeddings + else [config.hidden_size, vocab_size] + ), + dtype=paddle.get_default_dtype(), + ) + + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + self.weight = dist.shard_tensor( + self.weight, + get_mesh(-1), + [dist.Replicate(), dist.Shard(1)], + ) + + logger.info( + f"output-weight:{self.weight.shape} config.tie_word_embeddings={config.tie_word_embeddings}" + ) + if config.weight_share_add_bias and config.use_bias: + self.bias = self.create_parameter( + shape=[vocab_size], + dtype=paddle.get_default_dtype(), + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.constant.Constant(0.0) + ), + ) + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + self.bias = dist.shard_tensor( + self.bias, + get_mesh(-1), + [dist.Replicate(), dist.Shard(0)], + ) + else: + self.bias = None + + self.weight.is_distributed = ( + True if (vocab_size != config.vocab_size) else False + ) + if config.weight_share_add_bias and config.use_bias: + self.bias.is_distributed = ( + True if (vocab_size != config.vocab_size) else False + ) + + if self.weight.is_distributed: + self.weight.split_axis = 1 + if ( + config.weight_share_add_bias + and config.use_bias + and self.bias.is_distributed + ): + self.bias.split_axis = 0 + + if self.config.use_recompute_loss_fn: + logger.info( + "Using recompute_loss_fn, the calculation of logits will be moved into " + "loss_fn for memory optimization" + ) + + def forward(self, hidden_states, tensor_parallel_output=None): + + if self.config.use_recompute_loss_fn or self.config.use_sparse_head_and_loss_fn: + out_tensors = ( + (hidden_states, self.weight, self.bias) + if tensor_parallel_output is None + else (hidden_states, self.weight, self.bias, tensor_parallel_output) + ) + + return out_tensors + + return calc_lm_head_logits( + self.config, + hidden_states, + self.weight, + self.bias, + None, + tensor_parallel_output, + ) + + +class ErnieForCausalLMAuto(ErniePretrainedModelAuto): + """ + ErnieForCausalLMAuto is the 
model class for causal language modeling. + """ + + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + if config.sequence_parallel: + logger.info(f"using sequence_parallel, input seqlen={config.seqlen}") + if config.using_dynamic_sequence_length: + assert ( + not config.micro_batch_size + ), "sequence-parallel needs micro_batch_size setting when using dygramic_sequnence_length" + else: + assert config.seqlen is not None + + assert ( + config.tensor_parallel_degree > 1 + ), f"sequence-parallel needs mp>1, got mp={config.tensor_parallel_degree}" + + new_initializer_range = math.sqrt(0.3333 / config.hidden_size) + logger.info( + f"change initializer-range from {config.initializer_range} to {new_initializer_range}" + ) + config.initializer_range = new_initializer_range + self.config = config + self.ernie = ErnieModelAuto(config) + self.lm_head = ErnieLMHead(config) + self.criterion = ErniePretrainingCriterion(config) + + self.tie_weights() + + if self.config.use_rmsnorm: + if self.config.fuse_rms_norm: + logger.info("Use fusedRMSNorm") + else: + logger.info("Use normal RMSNorm") + else: + logger.info("Use normal LayerNorm") + + def _post_init(self, original_init, *args, **kwargs): + """ + Initialize weights and apply final processing + """ + super()._post_init(self, original_init, *args, **kwargs) + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + logger.info(f"using post init div: factor:{factor}") + + def scale_by_factor_if_valid(w): + if w.is_dist() and w._is_initialized(): + w.scale_(factor) + + if hasattr(self.config, "use_moe") and self.config.use_moe: + with paddle.no_grad(): + for left in self.ernie.layers: + if isinstance( + left.self_attn.o_proj, + (MOELayerAuto), + ): + for e in left.self_attn.o_proj.experts: + if isinstance(e, ErnieMoeMLP): + scale_by_factor_if_valid(e.weight) + else: + scale_by_factor_if_valid(left.self_attn.o_proj.weight) + + if isinstance( + left.mlp, + (MOELayerAuto), + ): + for e in left.mlp.experts: + if isinstance(e, ErnieMoeMLP): + scale_by_factor_if_valid(e.down_proj.weight) + else: + scale_by_factor_if_valid(left.mlp.down_proj.weight) + else: + with paddle.no_grad(): + for left in self.ernie.layers: + scale_by_factor_if_valid(left.self_attn.o_proj.weight) + scale_by_factor_if_valid(left.mlp.down_proj.weight) + + def get_input_embeddings(self): + + return self.ernie.embed_tokens + + def set_input_embeddings(self, value): + + self.ernie.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.ernie = decoder + + def get_decoder(self): + return self.ernie + + @staticmethod + def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id): + is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any( + input_ids == pad_token_id + ).numpy().item() + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: + attention_mask = (input_ids != pad_token_id).astype("int64") + else: + attention_mask = paddle.ones_like(input_ids, dtype="int64") + return attention_mask + + def prepare_inputs_for_generation( + self, + input_ids, + use_cache=False, + past_key_values=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + 
attention_mask = kwargs.get("attention_mask", None) + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": True, + "attention_mask": attention_mask, + "return_dict": True, + } + ) + return model_inputs + + @staticmethod + def update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False + ): + if ( + isinstance(outputs, tuple) + and len(outputs) > 1 + and not isinstance(outputs[1], paddle.Tensor) + ): + model_kwargs["past_key_values"] = outputs[1] + + if ( + isinstance(outputs, CausalLMOutputWithCrossAttentions) + and "past_key_values" in outputs + ): + model_kwargs["past_key_values"] = outputs.past_key_values + + if ( + "token_type_ids" in model_kwargs + and model_kwargs["token_type_ids"] is not None + ): + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = paddle.concat( + [token_type_ids, token_type_ids[:, -1:]], axis=-1 + ) + + if not is_encoder_decoder: + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [ + attention_mask, + paddle.ones([attention_mask.shape[0], 1], dtype="int64"), + ], + axis=-1, + ) + if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None: + role_ids = model_kwargs["role_ids"] + model_kwargs["role_ids"] = paddle.concat( + [role_ids, role_ids[:, -1:]], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids, + labels=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=False, + ignored_index=0, + inbatch_pack_offset=None, + token_type_ids=None, + ): + if isinstance(input_ids, list): + input_ids, labels = input_ids[:2] + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.ernie( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + inbatch_pack_offset=inbatch_pack_offset, + token_type_ids=token_type_ids, + ) + + hidden_states = outputs.last_hidden_state + + logits = self.lm_head( + hidden_states, + ) + + if return_dict: + if labels is not None: + loss, _ = self.criterion(logits, labels) + else: + loss = None + return CausalLMOutputWithCrossAttentionsAuto( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_loss=outputs.router_loss if self.config.use_moe else None, + ) + + assert labels is not None + router_loss = ( + outputs.router_loss + if hasattr(self.config, "use_moe") and self.config.use_moe + else None + ) + return self.criterion(logits, labels, router_loss) diff --git a/examples/pre-training/models/ernie/modeling_auto_pp.py b/examples/pre-training/models/ernie/modeling_auto_pp.py new file mode 100644 index 00000000..bf192e40 --- /dev/null +++ 
b/examples/pre-training/models/ernie/modeling_auto_pp.py @@ -0,0 +1,596 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle Ernie model""" +import math +import logging + + +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.pipelining.schedules import ( + Schedule1F1B, + ScheduleFThenB, + ScheduleVPP, +) +from paddle.distributed.auto_parallel.pipelining.stage import PipelineStage + +from paddle.distributed.fleet.utils import recompute + + +from models.moe.moe_utils_auto import get_mesh + +from .modeling_auto import ( + _parse_moe_group, + ErnieDecoderLayerAuto, + ErniePretrainedModelAuto, + LayerNorm, + RMSNorm, + ErniePretrainingCriterion, + ErnieLMHead, +) + +from paddle.distributed import in_auto_parallel_align_mode + + +logger = logging.getLogger(__name__) + +try: + from paddle.nn.functional.flash_attention import flash_attention + + logger.warning( + "Use flash attention in scaled-dot-product. Attention mask is deprecated" + ) +except (ImportError, ModuleNotFoundError): + flash_attention = None + + +__all__ = [ + "get_ernie_pp_schedule", + "ErnieForCausalLMAutoPP", +] + + +def parse_args(args): + hidden_states, attention_mask, position_ids = None, None, None + if isinstance(args, tuple): + if len(args) == 3: + hidden_states, attention_mask, position_ids = args + elif len(args) == 2: + hidden_states, attention_mask = args + elif len(args) == 1: + hidden_states = args[0] + else: + hidden_states = args + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + return hidden_states, attention_mask, position_ids + + +def return_args(hidden_states, attention_mask=None, position_ids=None): + + ret = (hidden_states,) + + if attention_mask is not None: + ret += (attention_mask.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if len(ret) == 1: + ret = ret[0] + + return ret + + +def global_mesh_starts_with_pp(): + + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + return mesh.get_mesh_with_dim("pp") + else: + return mesh + + +class ErnieChunk(nn.Layer): + def __init__(self, layers=None, is_first=False): + + super(ErnieChunk, self).__init__() + self.layers = layers + self.is_first = is_first + + def forward(self, *args, **kwargs): + """ + Forward function of the model. + + Args: + *args (tuple, optional): Tuple containing input tensors. If is_first is True, + input_ids, attention_mask and position_ids are required; otherwise, + it should be a tuple of output tensors from previous layer. Default None. + **kwargs (dict, optional): Dictionary containing input tensors. If is_first is False, + input_ids, attention_mask and position_ids are required; otherwise, it should be + an empty dictionary. Default None. 
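+                Note: as implemented, the inputs are read from ``kwargs`` only when
+                ``is_first`` is True; later chunks receive the previous chunk's output
+                tuple through the positional ``*args``.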
+ + Returns: + tuple (list[Tensor], Tensor, Tensor): Tuple containing output tensors from each decoder layer. + The first item is a list of output tensors from each decoder layer, the second item is the last + hidden state of the encoder, and the third item is the last position encoding index. + """ + if self.is_first: + input_ids = kwargs.get("input_ids") + attention_mask = kwargs.get("attention_mask") + position_ids = kwargs.get("position_ids") + outputs = tuple([input_ids, attention_mask, position_ids]) + for idx, (decoder_layer) in enumerate(self.layers): + outputs = decoder_layer(outputs) + return outputs + else: + outputs = args + for idx, (decoder_layer) in enumerate(self.layers): + outputs = decoder_layer(outputs) + return outputs + + +def manual_model_split(model, stage_idx, group, mode, pp_degree): + + num_hidden_layers = model.config.num_hidden_layers + virtual_pp_degree = model.config.virtual_pp_degree if mode == "VPP" else 1 + chunk_size = num_hidden_layers // virtual_pp_degree // pp_degree + chunk_num = virtual_pp_degree * pp_degree + layer_lists = None + + layer_lists = model.layers + + def _build_stage(model, stage_idx, group): + new_model = None + if stage_idx == 0: + new_model = ErnieChunk(layer_lists[:chunk_size], is_first=True) + else: + new_model = ErnieChunk( + layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], + is_first=False, + ) + stage = PipelineStage(new_model, stage_idx, chunk_num, group=group) + return stage + + stages = [] + for i in range(virtual_pp_degree): + stage = _build_stage(model, stage_idx + i * pp_degree, group) + stages.append(stage) + return stages + + +def get_ernie_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group): + + assert mode in ["VPP", "1F1B", "FThenB"] + stages = manual_model_split(model, group.rank, group, mode, pp_degree) + if mode == "VPP": + schedule = ScheduleVPP(stages, n_microbatches=n_microbatches, loss_fn=loss_fn) + elif mode == "1F1B": + schedule = Schedule1F1B( + stages[0], n_microbatches=n_microbatches, loss_fn=loss_fn + ) + else: + schedule = ScheduleFThenB( + stages[0], n_microbatches=n_microbatches, loss_fn=loss_fn + ) + return schedule + + +class ErnieDecoderLayerAutoPP(nn.Layer): + def __init__(self, config, layer_idx=0, ipp=0): + """ + Initializes the model. + + Args: + config (ErnieConfig): The configuration of the model. + layer_idx (int, optional): The index of the decoder layer. Defaults to 0. + ipp (int, optional): The index of the inner parallelism dimension. Defaults to 0. + + Returns: + None. 
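+            Note: ``ipp`` is the pipeline-stage index computed by ``get_pp_stage_id``;
+            it selects the ``pp`` mesh slice (``get_mesh(ipp)``) on which this layer's
+            tensors and reshard calls are placed.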
+ """ + if hasattr(config, "use_moe") and config.use_moe: + if config.moe_group in {"mp", "model", "tp", "mpdp"}: + assert config.sequence_parallel + logger.info( + f"disable FFN tensor model parallel, moe-group={config.moe_group}" + ) + config.disable_ffn_model_parallel = True + + config.moe_group = _parse_moe_group(config.moe_group) + if config.moe_group in fleet.auto.get_mesh().dim_names: + config.moe_world_size = fleet.auto.get_mesh().get_dim_size( + config.moe_group + ) + if config.moe_world_size < 0: + config.moe_world_size = 1 + else: + config.moe_world_size = 1 + + super().__init__() + self.config = config + + if hasattr(config, "use_moe") and config.use_moe: + if config.moe_group in {"mp", "model", "tp", "mpdp"}: + assert config.sequence_parallel + logger.info( + f"disable FFN tensor model parallel, moe-group={config.moe_group}" + ) + config.disable_ffn_model_parallel = True + + config.moe_group = _parse_moe_group(config.moe_group) + if config.moe_group in fleet.auto.get_mesh().dim_names: + config.moe_world_size = fleet.auto.get_mesh().get_dim_size( + config.moe_group + ) + if config.moe_world_size < 0: + config.moe_world_size = 1 + else: + config.moe_world_size = 1 + + self.layer_idx = layer_idx + self.ipp = ipp + self.placements = ( + [dist.Shard(1), dist.Shard(0)] + if self.config.sequence_parallel + else [dist.Shard(0), dist.Replicate()] + ) + self.embed_tokens = None + self.norm = None + self.lm_head = None + if layer_idx == 0: + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + if not in_auto_parallel_align_mode(): + self.embed_tokens.weight = dist.shard_tensor( + self.embed_tokens.weight, + get_mesh(), + [dist.Replicate(), dist.Shard(1)], + ) + self.layer = ErnieDecoderLayerAuto(config, layer_idx, ipp) + + Norm = RMSNorm if config.use_rmsnorm else LayerNorm + + if self.layer_idx == self.config.num_hidden_layers - 1: + self.norm = Norm(config, -1) + self.lm_head = ErnieLMHead(config) + + def recompute_training( + self, + layer_module, + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids, + ): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_gate_logits=False) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids, + use_reentrant=False, + ) + return hidden_states + + def forward(self, args): + output_attentions = self.config.output_attentions + use_cache = self.config.use_cache + output_hidden_states = self.config.output_hidden_states + return_dict = self.config.return_dict + past_key_values = None + past_key_value = None + token_type_ids = None + inbatch_pack_offset = None + if self.embed_tokens is not None: + + input_ids, attention_mask, position_ids = parse_args(args) + if isinstance(input_ids, list): + input_ids, labels = input_ids[:2] + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else 
self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None: + batch_size, seq_length = input_ids.shape + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds" + ) + + seq_length_with_past = seq_length + cache_length = 0 + + if past_key_values is not None: + cache_length = paddle.shape(past_key_values[0])[1] + seq_length_with_past += cache_length + inputs_embeds = self.embed_tokens(input_ids).astype( + self.embed_tokens.weight.dtype + ) + + if self.config.sequence_parallel: + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) + global_mesh = global_mesh_starts_with_pp() + + if position_ids is not None: + position_ids = dist.shard_tensor( + position_ids, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + + can_use_fa = self.config.use_flash_attn and flash_attention is not None + can_mem_eff_attn = ( + self.config.use_mem_eff_attn and inbatch_pack_offset is not None + ) + if can_use_fa or can_mem_eff_attn: + if attention_mask is not None: + attention_mask = None + + elif attention_mask is None: + attention_mask = paddle.ones( + (batch_size, seq_length_with_past), dtype=paddle.bool + ) + if attention_mask is not None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + cache_length, + inputs_embeds.dtype, + ) + attention_mask = dist.shard_tensor( + attention_mask, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + hidden_states = inputs_embeds + if self.config.tensor_parallel_degree > 1: + hidden_states = dist.reshard( + hidden_states, get_mesh(0), self.placements + ) + + args = return_args(hidden_states, attention_mask, position_ids) + + hidden_states, attention_mask, position_ids = parse_args(args) + + all_hidden_states = () if output_hidden_states else None + + all_router_loss = None + if hasattr(self.config, "use_moe") and self.config.use_moe: + all_router_loss = paddle.to_tensor(0.0) + all_router_loss = dist.shard_tensor( + all_router_loss, get_mesh(0), dist.Replicate() + ) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + has_gradient = not hidden_states.stop_gradient + if position_ids is not None: + position_ids_input = dist.reshard( + position_ids, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) + else: + position_ids_input = position_ids + attention_mask_input = ( + dist.reshard( + attention_mask, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) + if attention_mask is not None + else None + ) + token_type_ids_input = ( + dist.reshard( + token_type_ids, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) + if token_type_ids is not None + else None + ) + if self.config.use_recompute and has_gradient: + layer_outputs = self.recompute_training( + self.layer, + hidden_states, + attention_mask_input, + position_ids_input, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids_input, + ) + else: + layer_outputs = self.layer( + hidden_states, + attention_mask_input, + position_ids_input, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids_input, + ) + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + ret_args = return_args( + hidden_states, + attention_mask, + position_ids, + ) + if self.norm is not None: + hidden_states = 
self.norm(hidden_states) + if self.lm_head is not None: + logits = self.lm_head(hidden_states) + ret_args = return_args( + logits, + ) + + return ret_args + + +class ErniePretrainingCriterionPP(ErniePretrainingCriterion): + """ + Criterion for Ernie. + It calculates the final loss. + """ + + def __init__(self, config): + + super().__init__(config) + + def forward(self, prediction_scores, masked_lm_labels, router_loss=None): + """ + calculates the final loss + """ + losses = super().forward(prediction_scores, masked_lm_labels) + if losses is not None: + loss = losses[0] + else: + print("err") + return loss + + +class ErnieForCausalLMAutoPP(ErniePretrainedModelAuto): + """ + ErnieForCausalLMAutoPP is the model class for causal language modeling. + """ + + def __init__(self, config): + """ + Args: + config (Config): Config object containing hyperparameters and other configuration details. + + Returns: + None. + + Initializes the ErnieDecoder with the given config. + """ + super().__init__(config) + + if config.sequence_parallel: + logger.info(f"using sequence_parallel, input seqlen={config.seqlen}") + if config.using_dynamic_sequence_length: + assert ( + not config.micro_batch_size + ), "sequence-parallel needs micro_batch_size setting when using dygramic_sequnence_length" + else: + assert config.seqlen is not None + + assert ( + config.tensor_parallel_degree > 1 + ), f"sequence-parallel needs mp>1, got mp={config.tensor_parallel_degree}" + + new_initializer_range = math.sqrt(0.3333 / config.hidden_size) + logger.info( + f"change initializer-range from {config.initializer_range} to {new_initializer_range}" + ) + config.initializer_range = new_initializer_range + self.config = config + self.criterion = ErniePretrainingCriterionPP(config) + + if self.config.use_rmsnorm: + if self.config.fuse_rms_norm: + logger.info("Use fusedRMSNorm") + else: + logger.info("Use normal RMSNorm") + else: + logger.info("Use normal LayerNorm") + + decoder_layers = [] + + def get_pp_stage_id(layer_id): + pp_degree = global_mesh_starts_with_pp().shape[0] + chunk_size = self.config.num_hidden_layers // ( + pp_degree * self.config.virtual_pp_degree + ) + chunk_id = layer_id // chunk_size + pp_stage_id = chunk_id % pp_degree + return pp_stage_id + + for i in range(config.num_hidden_layers): + pp_stage_id = get_pp_stage_id(i) + decoder_layers.append(ErnieDecoderLayerAutoPP(config, i, pp_stage_id)) + self.layers = nn.LayerList(decoder_layers) + + def forward( + self, + input_ids, + labels=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=False, + ignored_index=0, + inbatch_pack_offset=None, + token_type_ids=None, + ): + outputs = return_args(input_ids, attention_mask, position_ids) + + for layer in self.layers: + outputs = layer(outputs) + + return outputs[0] diff --git a/examples/pre-training/models/moe/moe_layer_auto.py b/examples/pre-training/models/moe/moe_layer_auto.py new file mode 100644 index 00000000..1e7e8136 --- /dev/null +++ b/examples/pre-training/models/moe/moe_layer_auto.py @@ -0,0 +1,1198 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
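Editor's note on the pipeline split above: `manual_model_split` cuts the flat `model.layers` list into `pp_degree * virtual_pp_degree` equally sized chunks, and `get_pp_stage_id` (in `ErnieForCausalLMAutoPP`) maps a layer index to its stage via `chunk_id % pp_degree`. A minimal, framework-free sketch of that mapping, using illustrative sizes rather than the patch's actual configuration:

```python
# Interleaved (virtual-pipeline) layer-to-stage mapping, mirroring get_pp_stage_id
# and manual_model_split above. The counts below are illustrative only.
num_hidden_layers = 16
pp_degree = 4           # pipeline stages
virtual_pp_degree = 2   # chunks hosted per stage

chunk_size = num_hidden_layers // (pp_degree * virtual_pp_degree)  # layers per chunk

def get_pp_stage_id(layer_id: int) -> int:
    chunk_id = layer_id // chunk_size
    return chunk_id % pp_degree

assignment = {layer: get_pp_stage_id(layer) for layer in range(num_hidden_layers)}
# layers 0-1 -> stage 0, 2-3 -> stage 1, ..., then 8-9 wrap back to stage 0, so each
# stage ends up with virtual_pp_degree (= 2) non-contiguous chunks, matching
# manual_model_split's loop over `stage_idx + i * pp_degree`.
print(assignment)
```

This interleaving is why `ScheduleVPP` receives all per-rank stages when `mode == "VPP"`, while the `1F1B` and `FThenB` schedules are given only `stages[0]`.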
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple, List, Optional +import logging +import inspect +from collections import namedtuple +from contextlib import contextmanager +from functools import partial + +import paddle +from paddle import framework +from paddle import nn +from paddle.distributed.communication import stream +import paddle.nn.functional as F +from paddle.distributed import in_auto_parallel_align_mode + +from paddle.autograd import PyLayer +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute +from paddle.distributed import fleet + +import paddle.distributed as dist +from paddle import Tensor +from paddleformers.trainer.plugins.timer import get_timers +from models.moe.top2_gate_auto import TopKGateFused, TopKGateFusedAuto +from models.sequence_parallel_utils_auto import ScatterOp +from models.utils_auto import ( + manual_backward, +) + + +from models.moe.moe_utils_auto import get_flatten_mesh, get_mesh, _reshard +from paddle.incubate.nn.functional import ( + moe_combine, +) + + + + +logger = logging.getLogger(__name__) + + + + + + +@contextmanager +def profile(name): + """doc""" + if get_timers() is not None: + get_timers()(name).start() + yield + if get_timers() is not None: + get_timers()(name).stop() + + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + ], +) + + +class GateCombine(PyLayer): + + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + ctx.x = x + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + assert moe_combine is not None + ret = paddle.incubate.nn.functional.moe_combine(x, combine_weights, scatter_index) + return ret + + @staticmethod + def backward(ctx, grad_y, *_): + assert moe_combine is not None + grad_x, grad_combine_weight_helper = moe_combine.moe_combine_bwd( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + + +def combining_fused(x, combine_weights, scatter_index, hard_gate=False): + + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) + return x_gatherd.squeeze(-2) + ret = GateCombine.apply(x, combine_weights, scatter_index) + ret.stop_gradient = False + return ret + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + + output = None + orig_dtype = x.dtype + scatter_index = scatter_index.unbind(1) + dispatch_mask = dispatch_mask.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros( + [num_experts * capacity, x.shape[-1]], dtype="float32" + ) + updates = x * i_dispatch_mask.unsqueeze(-1).cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + + dim = x.shape[-1] + 
scatter_index = scatter_index.reshape([-1]) + num_k = combine_weights.shape[-1] + combine_weights = combine_weights.unsqueeze(1) + x = paddle.gather(x, scatter_index).reshape([-1, num_k, dim]) + return paddle.matmul(combine_weights, x).squeeze(1) + + + +class MOELayer(nn.Layer): + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + shared_experts: Optional[List[nn.Layer]] = None, + group: Group = None, + recompute=False, + k=2, + enable_bpr: bool = False, + all_to_all_dropout=0, + group_experts=False, + moe_statics=None, + ): + + super().__init__() + self.gate = gate + self.layer_idx = layer_idx + self.recompute = recompute + logger.info(f"using moe recompute={recompute}") + for p in self.gate.parameters(): + p.is_gate = True + if isinstance(experts, nn.LayerList): + self.experts = experts + else: + logger.info(f"using fused experts, type={type(experts)}") + self.experts = experts + self.shared_experts = shared_experts + + self.group = group + self.k = k + self.all_to_all_dropout = all_to_all_dropout + self.use_correction_bias = moe_statics is not None + self.moe_statics = moe_statics + if self.use_correction_bias: + logger.info( + f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}" + ) + assert self.gate.config.moe_use_aux_free + + self.is_mp_moe = ( + hasattr(fleet.fleet, "_hcg") + and group is fleet.get_hybrid_communicate_group().get_model_parallel_group() + ) + self.is_ep_moe = ( + hasattr(fleet.fleet, "_hcg") + and hasattr( + fleet.get_hybrid_communicate_group(), + "get_moe_sharding_parallel_world_size", + ) + and fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_world_size() + > 0 + ) + is_dummy_moe = dist.get_world_size(group) == 1 + + for p in experts.parameters(): + p.expert = not (self.is_mp_moe or is_dummy_moe) + p.no_sync = not (self.is_mp_moe or is_dummy_moe) + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe or self.is_ep_moe: + p.is_distributed = True + + expert_color = None + if self.is_ep_moe: + moe_grad_group = ( + fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group() + ) + expert_color = {"color": "moe_expert", "group": moe_grad_group} + elif ( + self.config.offline_quant_expert_weight + and self.config.clear_origin_weight_when_offline_quant + ): + expert_color = {"color": "moe_expert"} + + if expert_color is not None: + for p in self.experts.parameters(): + setattr(p, "color", expert_color) + + self.world_size = dist.get_world_size(self.group) + self.rank = dist.get_rank(self.group) + if self.world_size < 1: + self.world_size = 1 + if self.rank < 0: + self.rank = 0 + + self.num_local_experts = len(self.experts) + self.dispatch_by_task = ( + hasattr(self.gate, "dispatch_by_task") and self.gate.dispatch_by_task + ) + + if self.dispatch_by_task: + assert 0, "no supported, checkout earylier code" + assert self.num_local_experts == 1 + + self.input_preprocess = self.output_postprocess = None + self.group_experts = group_experts + self.config = self.gate.config + self.zero = paddle.to_tensor(0, dtype=paddle.float32) + + self._rr_moe_gate_dispatch = None + self._rr_moe_combine = None + self.use_norm_gate_recompute = None + + if self.config.use_recompute and self.config.skip_recompute_ops.get( + "moe_gate_dispatch", False + ): + self._rr_moe_gate_dispatch = None + if self.config.use_recompute and self.config.skip_recompute_ops.get( + "moe_combine", False + ): + self._rr_moe_combine = None + if hasattr(fleet.fleet, "_hcg"): + hcg = fleet.get_hybrid_communicate_group() + if 
( + hasattr(hcg, "get_moe_sharding_parallel_world_size") + and hcg.get_moe_sharding_parallel_world_size() > 0 + ): + moe_grad_group = hcg.get_moe_sharding_parallel_group() + for p in self.experts.parameters(): + setattr( + p, "color", {"color": "moe_expert", "group": moe_grad_group} + ) + + def forward_experts(self, dispatched_input): + + with profile("fwd-expert"): + dispatched_input = dispatched_input.reshape( + [ + self.world_size, + self.num_local_experts, + -1, + dispatched_input.shape[-1], + ] + ) + expert_outputs = [] + if isinstance(self.experts, nn.LayerList): + + chunks = dispatched_input.transpose([1, 0, 2, 3]).contiguous().unbind(0) + assert len(chunks) == len(self.experts), ( + len(chunks), + len(self.experts), + ) + for chunk, expert in zip(chunks, self.experts): + expert_outputs += [expert(chunk)] + + expert_output = paddle.stack(expert_outputs, axis=1) + + else: + dispatched_input = dispatched_input.transpose([1, 0, 2, 3]) + dispatched_input.contiguous() + orig_shape = dispatched_input.shape + chunks = dispatched_input.reshape([orig_shape[0], -1, orig_shape[-1]]) + chunks = self.experts(chunks) + chunks = chunks.reshape(orig_shape[:-1] + [chunks.shape[-1]]).unbind(0) + expert_outputs += chunks + expert_output = paddle.stack(expert_outputs, axis=1) + return expert_output + + def fused_gate_logits_process( + self, gate_logits, token_type_ids, offload_helper=None + ): + + k = self.k + experts_type_ids = self.gate.experts_type_ids + use_hard_gate = self.config.moe_use_hard_gate + max_prob = None + + if token_type_ids is not None and use_hard_gate: + if offload_helper is None: + offload_helper = dict() + lm_mask = token_type_ids == 0 + is_lm = lm_mask.any() + mm_mask = token_type_ids == 1 + is_mm = mm_mask.any() + seq_lm = lm_mask.sum() + seq_mm = mm_mask.sum() + lm_mask = lm_mask.unsqueeze(1) & (experts_type_ids == 0).unsqueeze(0) + mm_mask = mm_mask.unsqueeze(1) & (experts_type_ids == 1).unsqueeze(0) + offload_helper["lm_mask"] = [lm_mask, is_lm, seq_lm] + offload_helper["mm_mask"] = [mm_mask, is_mm, seq_mm] + + is_lm = offload_helper["lm_mask"][1] + prob = paddle.zeros_like(gate_logits) + if is_lm: + lm_mask = offload_helper["lm_mask"][0] + seq_lm_cpu = offload_helper["lm_mask"][2] + lm_mask_nonzero = lm_mask.nonzero() + lm_partial_gate_logits = gate_logits.gather_nd(lm_mask_nonzero).reshape( + [seq_lm_cpu, -1] + ) + if self.group_experts: + lm_prob = self.gate.act( + lm_partial_gate_logits.reshape( + [lm_partial_gate_logits.shape[0], k, -1] + ) + ) + max_prob = lm_prob.max(-1, keepdim=True) + lm_prob /= max_prob + else: + lm_prob = self.gate.act(lm_partial_gate_logits) + prob = paddle.scatter_nd_add(prob, lm_mask_nonzero, lm_prob.flatten()) + is_mm = offload_helper["mm_mask"][1] + if is_mm: + mm_mask = offload_helper["mm_mask"][0] + seq_mm_cpu = offload_helper["mm_mask"][2] + mm_mask_nonzero = paddle.nonzero(mm_mask) + mm_partial_gate_logits = gate_logits.gather_nd(mm_mask_nonzero).reshape( + [seq_mm_cpu, -1] + ) + mm_prob = self.gate.act(mm_partial_gate_logits) + prob = paddle.scatter_nd_add(prob, mm_mask_nonzero, mm_prob.flatten()) + else: + if self.group_experts: + prob = self.gate.act(gate_logits.reshape([gate_logits.shape[0], k, -1])) + max_prob = prob.max(-1, keepdim=True) + prob /= max_prob + prob = prob.reshape([prob.shape[0], -1]) + else: + prob = self.gate.act(gate_logits) + return prob, max_prob + + def gate_and_distpach(self, input, token_type_ids): + + seqlen, d_model = input.shape + args = () + if token_type_ids is not None: + token_type_ids = 
token_type_ids.reshape([-1]) + args = (token_type_ids,) + + use_fuse = isinstance(self.gate, (TopKGateFused)) + if use_fuse: + if self.use_norm_gate_recompute: + ( + gate_logits, + capacity, + router_loss, + norm_res, + ) = self.fused_norm_gate(input) + input = norm_res + else: + ( + gate_logits, + capacity, + router_loss, + ) = self.gate(input, *args) + else: + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + router_loss, + gate_logits, + ) = self.gate( + input, + *args, + correction_bias=( + self.moe_statics.e_score_correction_bias[0] + if self.use_correction_bias + else None + ), + ) + prob = None + if self.config.moe_multimodal_paired_experts: + assert token_type_ids is not None + input = paddle.concat( + [input, token_type_ids.unsqueeze(-1).astype(input.dtype)], axis=-1 + ) + if self.input_preprocess is not None: + input, gate_logits = self.input_preprocess(input, gate_logits, capacity) + if use_fuse: + k = self.k + prob, max_prob = self.fused_gate_logits_process(gate_logits, token_type_ids) + + assert moe_ops is not None + with profile("dispatch_op"): + if ( + "corr_bias" + in inspect.signature(moe_ops.moe_gate_dispatch).parameters + ): + if self.use_correction_bias: + compat_args = (self.moe_statics.e_score_correction_bias[0],) + else: + compat_args = (None,) + else: + assert ( + not self.use_correction_bias + ), "correction bias not supported, rebuild moe-ops" + compat_args = () + if not self.config.use_ep_comm_overlap: + if self._rr_moe_gate_dispatch is None: + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = moe_ops.moe_gate_dispatch( + input, + prob, + *compat_args, + k=k, + capacity=capacity, + use_pad=True, + ) + else: + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = self._rr_moe_gate_dispatch( + input, + prob, + compat_args, + k=k, + capacity=capacity, + use_pad=True, + ) + else: + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = moe_ops.moe_gate_dispatch_permute( + input, + prob, + *compat_args, + k=k, + capacity=capacity, + world_size=self.group.nranks, + ) + dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0))) + if self.use_correction_bias and framework._dygraph_tracer()._has_grad: + if self.gate.config.multimodel_experts: + for i in range(len(self.moe_statics.expert_usage)): + self.moe_statics.expert_usage[i] += dispatch_mask[ + self.gate.experts_type_mask[i] + ].detach() + else: + self.moe_statics.expert_usage[0] += dispatch_mask.detach() + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + scatter_index.stop_gradient = False + dispatch_mask.stop_gradient = True + + scatter_index = scatter_index.transpose([1, 0]) + if self.group_experts: + if max_prob is not None: + if token_type_ids is not None: + p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1)) + p = paddle.scatter_nd_add( + p, paddle.nonzero(token_type_ids == 0), -1 + max_prob + ) + else: + p = max_prob + combine_weights_unnorm = ( + combine_weights_unnorm.unsqueeze(-1) * p + ).squeeze(-1) + prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape( + [p.shape[0], -1] + ) + if self.gate.norm_gate_logits: + combine_weights = combine_weights_unnorm / paddle.clip( + combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + else: + combine_weights = combine_weights_unnorm + combine_weights = combine_weights.cast(dispatched_input.dtype) + else: + dispatched_input = dispatching( + input, + dispatch_mask, + scatter_index, 
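+                # non-fused dispatch path: `dispatching` (defined above) scatters every
+                # routed token copy into a dense [num_experts * capacity, d_model] buffer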
+ num_experts=self.world_size * self.num_local_experts, + capacity=capacity, + ) + if self.use_correction_bias and framework._dygraph_tracer()._has_grad: + usage = paddle.bincount( + scatter_index.reshape([-1]) // capacity, + minlength=self.world_size * self.num_local_experts, + ) + assert ( + not self.config.multimodel_experts + ), "correction bias not supported, use top2-fused gate" + self.moe_statics.expert_usage[0] += usage.detach() + if not self.config.use_ep_comm_overlap: + dispatched_input = dispatched_input.reshape( + [ + self.world_size * self.num_local_experts, + capacity, + ( + d_model + if not self.config.moe_multimodal_paired_experts + else d_model + 1 + ), + ] + ) + else: + assert ( + len(dispatched_input.shape) == 4 + and dispatched_input.shape[1] == self.world_size + and dispatched_input.shape[0] == self.num_local_experts + ), ( + f"When using ep_comm_overlap, moe_gate_dispatch_permute is needed. " + f"Expected dispatched_input to have shape[1] == {self.world_size} " + f"and shape[0] == {self.num_local_experts}, " + f"but got shape {dispatched_input.shape}" + ) + dispatched_input = dispatched_input + dispatch_mask.stop_gradient = True + scatter_index.stop_gradient = False + return ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + prob, + ) + + def _calc_router_loss( + self, + dispatch_mask, + gate_logits, + gate_prob, + num_experts, + use_group, + layer_idx, + token_type=None, + tokens_type_mask=None, + dispatch_tokens_mask=None, + prefix="", + ): + router_loss, l_aux, orthogonal_loss, zloss = 0.0, None, None, None + if self.gate.config.moe_aux_loss_lambda: + l_aux = self.gate._cal_aux_loss( + gate_prob, + dispatch_mask, + num_experts, + use_group, + tokens_type_mask, + dispatch_tokens_mask, + ) + router_loss += self.gate.moe_aux_loss_lambda[token_type or 0] * l_aux + else: + router_loss += self.zero * gate_prob[0, 0] + if self.gate.config.moe_orthogonal_loss_lambda: + orthogonal_loss = self.gate._cal_orthogonal_loss(token_type, use_group) + router_loss += ( + self.gate.moe_orthogonal_loss_lambda[token_type or 0] * orthogonal_loss + ) + + + return router_loss + + def calc_router_loss_and_logging( + self, + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids, + dispatch_token_type_ids=None, + offload_helper=None, + ): + + use_fuse = isinstance(self.gate, (TopKGateFused)) + if use_fuse: + assert gate_prob is not None + if token_type_ids is not None and self.gate.config.moe_use_hard_gate: + if not self.gate.weight.stop_gradient: + lm_tokens_mask = token_type_ids == 0 + if offload_helper is not None: + is_lm = offload_helper["lm_mask"][1] + else: + is_lm = lm_tokens_mask.any() + if is_lm: + dispatch_tokens_mask = ( + dispatch_token_type_ids == 0 + if dispatch_token_type_ids is not None + else None + ) + router_loss += self._calc_router_loss( + ( + dispatch_mask[self.gate.experts_type_mask[0]] + if hasattr(self.gate, "experts_type_mask") + else dispatch_mask + ), + ( + gate_logits[:, self.gate.experts_type_mask[0]] + if hasattr(self.gate, "experts_type_mask") + else gate_logits + ), + ( + gate_prob[:, self.gate.experts_type_mask[0]] + if hasattr(self.gate, "experts_type_mask") + else gate_prob + ), + ( + self.gate.num_experts_list[0] + if hasattr(self.gate, "num_experts_list") + else self.gate.num_experts_tensor + ), + self.group_experts, + self.layer_idx, + 0, + lm_tokens_mask, + dispatch_tokens_mask, + prefix="lm", + ) + mm_tokens_mask = token_type_ids == 1 + if offload_helper is 
not None: + is_mm = offload_helper["mm_mask"][1] + else: + is_mm = mm_tokens_mask.any() + if is_mm: + dispatch_tokens_mask = ( + dispatch_token_type_ids == 1 + if dispatch_token_type_ids is not None + else None + ) + router_loss += self._calc_router_loss( + dispatch_mask[self.gate.experts_type_mask[1]], + gate_logits[:, self.gate.experts_type_mask[1]], + gate_prob[:, self.gate.experts_type_mask[1]], + self.gate.num_experts_list[1], + False, + self.layer_idx, + 1, + mm_tokens_mask, + dispatch_tokens_mask, + prefix="mm", + ) + + else: + router_loss += self._calc_router_loss( + dispatch_mask, + gate_logits, + gate_prob, + self.gate.num_experts_tensor, + self.group_experts, + self.layer_idx, + ) + + return router_loss + + def combine_expert_output(self, expert_output, combine_weights, scatter_index): + + expert_output = expert_output.reshape([-1, expert_output.shape[-1]]) + use_fuse = isinstance(self.gate, (TopKGateFused)) + combine_fn = combining_fused if use_fuse else combining + combined_output = combine_fn(expert_output, combine_weights, scatter_index) + + if self.output_postprocess is not None: + combined_output = self.output_postprocess(combined_output) + return combined_output + + def forward_single_stage(self, dispatched_input, stage_id): + assert isinstance(self.experts, nn.LayerList) + return self.experts[stage_id](dispatched_input) + + def forward( + self, + input: Tensor, + token_type_ids=None, + ) : + pass + +def combining_fused_auto(x, combine_weights, scatter_index, hard_gate=False): + """ + Args: + x: Tensor[seq, dim] + combine_weights: [s, k] + scatter_index: ** [k, s] ** + + Returns: + y: Tensor[s, dim] + """ + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) + return x_gatherd.squeeze(-2) + ret = paddle.incubate.nn.functional.moe_combine(x, combine_weights, scatter_index) + + ret.stop_gradient = False + return ret + + +def detach_and_requires_grad_(*args): + """detach_and_requires_grad_""" + ret = [a.detach() if a is not None else None for a in args] + for r, a in zip(ret, args): + if a is not None: + r.stop_gradient = a.stop_gradient + return ret + + +def bpr_preprocess(input, logits, capacity, buffer): + """impletment bpr sorting""" + assert input.ndim == 2, input.shape + idx = paddle.argsort(logits.max(-1), axis=0, descending=True) + input = input[idx] + logits = logits[idx] + buffer["idx"] = idx + return input, logits + + +def bpr_postprocess(output, buffer): + """bpr sorting""" + idx = buffer.pop("idx") + rev_idx = paddle.argsort(idx) + output = output[rev_idx] + return output + + +class MOELayerAuto(MOELayer): + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + shared_experts: Optional[List[nn.Layer]] = None, + group: Group = None, + recompute=False, + k=2, + enable_pbr: bool = False, + all_to_all_dropout=0, + group_experts=False, + config=None, + ipp=0, + ): + nn.Layer.__init__(self) + self.config = config + self.gate = gate + self.layer_idx = layer_idx + self.ipp = ipp + self.recompute = recompute + logger.info(f"using moe recompute={recompute}") + for p in self.gate.parameters(): + p.is_gate = True + if isinstance(experts, nn.LayerList): + self.experts = experts + else: + logger.info(f"using fused experts, type={type(experts)}") + self.experts = experts + self.shared_experts = shared_experts + + self.group = group + self.k = k + self.all_to_all_dropout = all_to_all_dropout + is_mp_moe = ( + hasattr(fleet.fleet, "_hcg") + and group is fleet.get_hybrid_communicate_group().get_model_parallel_group() + ) + is_dummy_moe = 
config.moe_world_size == 1 + + for p in experts.parameters(): + p.expert = not (is_mp_moe or is_dummy_moe) + p.no_sync = not (is_mp_moe or is_dummy_moe) + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if is_mp_moe or is_mp_moe: + p.is_distributed = True + + self.world_size = config.moe_world_size + if self.group in fleet.auto.get_mesh().dim_names: + self.rank = fleet.auto.get_mesh().get_rank_by_dim_and_process_id( + self.group, dist.get_rank() + ) + if self.rank < 0: + self.rank = 0 + else: + self.rank = 0 + + self.num_experts_per_group = len(self.experts) + self.ep_group_num = config.moe_world_size + self.num_local_experts = self.num_experts_per_group // self.ep_group_num + + self.moe_mesh_dim = 0 if config.moe_group == "dp" else 1 + self.dispatch_by_task = ( + hasattr(self.gate, "dispatch_by_task") and self.gate.dispatch_by_task + ) + + if self.dispatch_by_task: + assert 0, "no supported, checkout earylier code" + assert self.num_local_experts == 1 + + if enable_pbr: + logger.info("using BPR") + prepost_process_buffer = {} + self.input_preprocess = partial( + bpr_preprocess, buffer=prepost_process_buffer + ) + self.output_postprocess = partial( + bpr_postprocess, buffer=prepost_process_buffer + ) + else: + self.input_preprocess = self.output_postprocess = None + self.group_experts = group_experts + + def _cal_multimodel_experts_prob( + self, gate_logits, token_type_ids, group_experts, moe_k + ): + + if not self.gate.experts_type_ids.is_dist(): + self.gate.experts_type_ids = dist.shard_tensor( + self.gate.experts_type_ids, + get_mesh(), + [dist.Replicate(), dist.Replicate()], + ) + return super()._cal_multimodel_experts_prob( + gate_logits, token_type_ids, group_experts, moe_k + ) + + def forward_experts(self, dispatched_input): + """ + call experts sequently + Args: + dispatched_input: Tensor[num_experts, capacity, dim] + Returns: + expert_output: Tensor[num_experts, capacity, dim] + """ + assert isinstance(self.experts, nn.LayerList) + if self.config.moe_group == "mp": + local_input_list = dist.auto_parallel.api.moe_sub_mesh_tensors( + dispatched_input, + get_mesh(self.ipp), + self.moe_mesh_dim, + [dist.Shard(2), dist.Shard(0)], + ) + + assert len(self.experts) % len(local_input_list) == 0, ( + "num of experts must be divided by num of ep_group, " + f"but got {len(self.experts)} and {len(local_input_list)}" + ) + expert_group_outputs = [] + for i_ep_group, local_input in enumerate(local_input_list): + chunks = local_input.unbind(1) + experts = self.experts[ + i_ep_group + * self.num_local_experts : (i_ep_group + 1) + * self.num_local_experts + ] + ep_output = [] + assert len(experts) == len( + chunks + ), f"num of experts must be equal to num of chunks, but got {len(experts)} and {len(chunks)}" + for chunk_id, (chunk, expert) in enumerate(zip(chunks, experts)): + ep_output += [expert(chunk)] + expert_group_outputs += [paddle.stack(ep_output, axis=1)] + return expert_group_outputs + else: + chunks = dispatched_input.unbind(1) + expert_outputs = [] + assert len(chunks) == len(self.experts), (len(chunks), len(self.experts)) + for chunk, expert in zip(chunks, self.experts): + expert_outputs += [expert(chunk)] + expert_output = paddle.stack(expert_outputs, axis=1) + return expert_output + + def gate_and_distpach(self, input, token_type_ids): + """ + calc gate and dispatch inputs (and do logging, optionaly) + Args: + input: Tensor[seq, dim], float + token_type_ids: Tensor[seq], int + Returns: + dispatched_input: Tensor[num_experts, capacity, dim] + combine_weights: [seq, k] + 
scatter_index: [seq, k] + router_loss: scalar + gate_logits: [seq, num_experts] + """ + with profile("moe-gate"): + args = () + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + args = (token_type_ids,) + use_fuse = isinstance(self.gate, (TopKGateFusedAuto)) + if use_fuse: + (gate_logits, capacity, router_loss, local_capacity) = self.gate( + input, *args + ) + else: + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + router_loss, + gate_logits, + ) = self.gate(input, *args) + prob = None + if self.input_preprocess is not None: + input, gate_logits = self.input_preprocess(input, gate_logits, capacity) + + with profile("moe-dispatch"): + if use_fuse: + k = self.k + prob, max_prob = self.fused_gate_logits_process( + gate_logits, token_type_ids + ) + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = paddle.incubate.nn.functional.moe_gate_dispatch(input, prob, None, k, local_capacity, True) + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + dispatch_mask.stop_gradient = True + + scatter_index = scatter_index.transpose([1, 0]) + + if self.group_experts: + if max_prob is not None: + if token_type_ids is not None: + p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1)) + p = paddle.scatter_nd_add( + p, paddle.nonzero(token_type_ids == 0), -1 + max_prob + ) + else: + p = max_prob + combine_weights_unnorm = ( + combine_weights_unnorm.unsqueeze(-1) * p + ).squeeze(-1) + prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape( + [p.shape[0], -1] + ) + combine_weights = combine_weights_unnorm / paddle.clip( + combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + combine_weights = combine_weights.cast(dispatched_input.dtype) + else: + dispatched_input = dispatching( + input, + dispatch_mask, + scatter_index, + num_experts=self.config.moe_num_experts, + capacity=capacity, + ) + dispatch_mask.stop_gradient = True + scatter_index.stop_gradient = False + return ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + prob, + ) + + def combine_expert_output(self, expert_output, combine_weights, scatter_index): + """ + Combine Expert output + Args: + expert_output: Tensor[num_experts, caapcity, dim] + combine_weights: + Returns: + combined_output: Tensor[seqlen, dim] + """ + with profile("moe-combine"): + if self.config.moe_use_all2all and self.config.moe_group == "mp": + expert_output = dist.auto_parallel.moe_utils._dist_reshape( + expert_output, + [-1, expert_output.shape[-1]], + get_flatten_mesh(get_mesh(self.ipp)), + [dist.Shard(0)], + ) + else: + expert_output = expert_output.reshape([-1, expert_output.shape[-1]]) + + if not self.config.moe_use_all2all: + if self.config.moe_group == "mp": + expert_output = dist.reshard( + expert_output, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) + else: + expert_output = dist.reshard( + expert_output, get_mesh(), [dist.Shard(0), dist.Replicate()] + ) + use_fuse = isinstance(self.gate, (TopKGateFusedAuto)) + combine_fn = combining_fused_auto if use_fuse else combining + combined_output = combine_fn(expert_output, combine_weights, scatter_index) + + if self.output_postprocess is not None: + combined_output = self.output_postprocess(combined_output) + return combined_output + + def forward( + self, + input: Tensor, + token_type_ids=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ + Args: + input (`Tensor`): The input data with shape ``(s, 
d)``. + Only one token is supported for now. + token_type_ids (`Tensor`) int64 tensor with shape (s), + if specified, rount tensor according to `token_type_ids`. + Returns: + output (`Tensor`): The final output tensor with shape ``(s, d)`` where ``m`` is the + size of model parameters. + combine_weights (`Tensor`, optional): A tensor with shape ``(s,)``, which represents weights + for each expert in MoE. + router_loss (`Tensor`, optional): A scalar tensor representing the loss of routing function. + """ + if self.shared_experts is not None: + shared_expert_input = dist.reshard( + input, + get_mesh(self.ipp), + [dist.Shard(1), dist.Replicate()], + ) + if input.ndim == 3: + orig_shape = input.shape + input = dist.reshard( + input, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)] + ) + if self.config.moe_use_all2all: + input = dist.auto_parallel.moe_utils._dist_reshape( + input, + [-1, input.shape[-1]], + get_flatten_mesh(get_mesh(self.ipp)), + [dist.Shard(0)], + ) + else: + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + assert ( + len(input.shape) == 2 + ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + seqlen, d_model = input.shape + + if token_type_ids is not None: + token_type_ids = token_type_ids.clone()[:, :-1] + if self.config.sequence_parallel: + token_type_ids = token_type_ids.reshape([-1]) + token_type_ids.stop_gradient = True + + assert self.gate is not None + if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: + orig_shape_2 = input.shape + output = self.forward_experts(input) + output += self.gate.weight.sum() * 0.0 + output = output.reshape(orig_shape or orig_shape_2) + return output, None, 0 + + ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + gate_prob, + ) = self.gate_and_distpach(input, token_type_ids) + if self.config.moe_use_all2all and self.config.moe_group == "mp": + dispatched_input = _reshard( + dispatched_input, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(1)] + ) + if self.config.moe_group == "mp": + dispatched_input = dist.reshard( + dispatched_input, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)] + ) + + if self.shared_experts is not None: + shared_out = self.shared_experts(shared_expert_input) + dispatched_input = dispatched_input.reshape( + [self.config.moe_world_size, self.num_local_experts, -1, d_model] + ) + expert_out = self.forward_experts(dispatched_input) + if self.config.moe_group == "mp": + expert_out = dist.auto_parallel.api.moe_global_mesh_tensor( + expert_out, + get_mesh(self.ipp), + [dist.Shard(2), dist.Shard(0)], + self.moe_mesh_dim, + ) + expert_out = dist.auto_parallel.moe_utils._dist_reshape( + expert_out, + [self.config.moe_world_size * self.num_local_experts, -1, d_model], + get_mesh(self.ipp), + [dist.Shard(1), dist.Shard(0)], + ) + expert_out = dist.reshard( + expert_out, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(1)] + ) + if not in_auto_parallel_align_mode(): + router_loss2 = self.calc_router_loss_and_logging( + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids, + ) + else: + router_loss2 = router_loss + router_loss2 = dist.shard_tensor( + router_loss2, get_flatten_mesh(get_mesh(self.ipp)), [dist.Replicate()] + ) + combined_output = self.combine_expert_output( + expert_out, combine_weights, scatter_index + ) + + if self.shared_experts is not None: + shared_out = dist.auto_parallel.moe_utils._dist_reshape( + shared_out, + [-1, shared_out.shape[-1]], 
+ get_flatten_mesh(get_mesh(self.ipp)), + [dist.Shard(0)], + ) + combined_output += shared_out + + if orig_shape: + if self.config.moe_use_all2all: + combined_output = dist.auto_parallel.moe_utils._dist_reshape( + combined_output, + orig_shape[:-1] + [combined_output.shape[-1]], + get_mesh(self.ipp), + [dist.Shard(1), dist.Shard(0)], + ) + router_loss2 = _reshard( + router_loss2, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) + else: + combined_output = combined_output.reshape( + orig_shape[:-1] + [combined_output.shape[-1]] + ) + return combined_output, combine_weights, router_loss2, gate_logits diff --git a/examples/pre-training/models/moe/moe_utils_auto.py b/examples/pre-training/models/moe/moe_utils_auto.py new file mode 100644 index 00000000..fbaba34f --- /dev/null +++ b/examples/pre-training/models/moe/moe_utils_auto.py @@ -0,0 +1,40 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import paddle.distributed as dist +from paddle.distributed import fleet + + +def get_flatten_mesh(mesh): + + return dist.ProcessMesh(mesh.process_ids) + + +def get_mesh(pp_idx=0): + + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp", pp_idx) + return mesh + + +def _reshard(tensor, mesh, placements): + + dst_tensor = dist.auto_parallel.moe_utils._dist_reshape( + tensor, tensor.shape, mesh, placements + ) + return dst_tensor diff --git a/examples/pre-training/models/moe/top2_gate_auto.py b/examples/pre-training/models/moe/top2_gate_auto.py new file mode 100644 index 00000000..989329b7 --- /dev/null +++ b/examples/pre-training/models/moe/top2_gate_auto.py @@ -0,0 +1,650 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
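Editor's note before the gate code that follows: the gates compute an auxiliary load-balancing loss of the form `l_aux = num_experts * sum(me * ce)`, where `me` is the mean router probability per expert and `ce` is the fraction of routed token copies per expert (see `cal_aux_loss_func` below). A minimal NumPy sketch of top-k routing plus this loss, ignoring capacity limits and using illustrative sizes:

```python
import numpy as np

# Top-k routing + auxiliary load-balancing loss, sketching cal_aux_loss_func below.
# Sizes are illustrative; capacity limits and token-type masks are ignored.
rng = np.random.default_rng(0)
num_tokens, num_experts, k = 8, 4, 2

logits = rng.standard_normal((num_tokens, num_experts))
gate_prob = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)  # softmax gate act

topk_idx = np.argsort(-gate_prob, axis=-1)[:, :k]               # top-k experts per token
dispatch_count = np.bincount(topk_idx.reshape(-1), minlength=num_experts)

seqlen = float(num_tokens)
ce = dispatch_count / seqlen           # dispatch fraction per expert (sums to k)
me = gate_prob.sum(axis=0) / seqlen    # mean gate probability per expert (sums to 1)
l_aux = num_experts * float(np.sum(me * ce))
# Perfectly uniform routing gives l_aux == k; the real implementation additionally
# handles token-type masks and divides by moe_k when use_group is enabled.
print(l_aux)
```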
+ +from typing import Tuple +from functools import partial +import logging +import numpy as np +import paddle +from paddle import Tensor +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.utils import unique_name +from paddle.nn.clip import _squared_l2_norm +from paddle.distributed import fleet +from models.moe.moe_utils_auto import get_mesh, get_flatten_mesh + + +try: + from custom_setup_ops import matmul_bwd +except ImportError: + matmul_bwd = None + + +logger = logging.getLogger(__name__) + + +def cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + global_aux_loss=False, + rank=None, + group=None, +): + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + + scale = None + if dispatch_tokens_mask is not None: + seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() + if ( + tokens_mask is not None + and gate_prob.shape[0] != dispatch_tokens_mask.shape[0] + ): + scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) + elif tokens_mask is not None: + seqlen_float = tokens_mask.sum() + else: + seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts + seqlen_float = paddle.clip(seqlen_float, min=1e-6) + + if len(dispatch_mask.shape) == 2: + dispatch_mask = dispatch_mask.sum(0) + ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float + me = paddle.sum(gate_prob, axis=0) / seqlen_float + if global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=group) + dist.all_gather(ce_list, ce, group=group) + + me_list[rank] = me + ce_list[rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + + l_aux = paddle.sum(me * ce) * num_experts + if use_group: + l_aux = l_aux / moe_k + + if scale is not None: + l_aux = l_aux + (scale - 1) * l_aux.detach() + + return l_aux + +def masked_fill(x, mask, value): + + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +@paddle.no_grad() +def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): + + n, _ = M.shape + P = F.softmax(-M / lam) + u = paddle.zeros(n, "float32") + for _ in range(max_iters): + if (u - P.sum(1)).abs().max() < epsilon: + break + u = P.sum(1) + P *= (r / (u + 1e-8)).reshape((-1, 1)) + P *= (c / (P.sum(0) + 1e-8)).reshape((1, -1)) + P = paddle.where(~P.isnan(), P, paddle.zeros_like(P)) + return P, _ + + +def cast_if_needed(x, dtype): + + return x.cast(dtype) if x.dtype != dtype else x + + +class FusedGateDetachMatmul(paddle.autograd.PyLayer): + + @staticmethod + def forward(ctx, x, w): + + ctx.dtype = paddle.float32 + ctx.save_for_backward(x, w) + return F.linear(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype)) + + @staticmethod + def backward(ctx, y_grad): + + x, w = ctx.saved_tensor() + assert ctx.dtype == y_grad.dtype, "dtype not match" + x_g, w_g = matmul_bwd( + cast_if_needed(x, ctx.dtype), + cast_if_needed(w, ctx.dtype), + y_grad, + False, + False, + ) + return cast_if_needed(x_g, x.dtype), cast_if_needed(w_g, w.dtype) + + +def gate_detach_matmul(x, weight, use_fuse, use_fake_gate=False): + + if use_fuse: + score = FusedGateDetachMatmul.apply(x, weight) + else: + x = cast_if_needed(x, paddle.float32) + score = F.linear(x, weight) + + if use_fake_gate: + score = paddle.randn(score.shape).astype(score.dtype) + score - score + return score + + +class Top2Gate(nn.Layer): + + def 
__init__(self, config, layer_idx: int, group, gate_weight=None) -> None: + + super().__init__() + self.config = config + self.fuse_gate_detach_matmul = config.fuse_gate_detach_matmul + if self.fuse_gate_detach_matmul: + assert matmul_bwd is not None, "matmul_bwd is not supported" + + self.use_fake_gate = config.use_fake_gate + if self.use_fake_gate: + logging.warning( + "You are use fake_gate, which is just for test, not for real training." + ) + + self.model_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.num_experts_tensor = ( + sum(config.moe_num_experts) + if config.multimodel_experts + else config.moe_num_experts + ) + + self.cap = config.moe_capacity + self.group = group + + self.layer_idx = layer_idx + self.global_aux_loss = config.global_aux_loss + if self.global_aux_loss: + self.rank = dist.get_rank(self.group) + + self.sinkhorn_2gate = config.sinkhorn_2gate + self.sinkhorn_temp = config.sinkhorn_temp + self.use_token_type_bias = config.moe_use_token_type_bias + self.use_correction_bias = config.moe_use_aux_free + + if config.moe_gate_act == "softmax": + self.act = partial(F.softmax, axis=-1) + elif config.moe_gate_act == "sigmoid": + self.act = F.sigmoid + else: + raise ValueError(f"{config.moe_gate_act} is not supported.") + self.no_jitter = True + self.expert_drop = False + self.eye_matrix = None + self.eye_matrix_size = None + self.norm_gate_logits = config.moe_norm_gate_logits + self.one = paddle.ones([], dtype="float32") + + self.moe_aux_loss_lambda = paddle.to_tensor( + config.moe_aux_loss_lambda, dtype="float32" + ) + self.moe_z_loss_lambda = paddle.to_tensor( + config.moe_z_loss_lambda, dtype="float32" + ) + self.moe_orthogonal_loss_lambda = paddle.to_tensor( + config.moe_orthogonal_loss_lambda, dtype="float32" + ) + if self.moe_aux_loss_lambda.ndim == 0: + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) + if self.moe_z_loss_lambda.ndim == 0: + self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) + if self.moe_orthogonal_loss_lambda.ndim == 0: + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze( + 0 + ) + + self.experts_type_ids = None + if config.moe_orthogonal_loss_lambda: + if hasattr(fleet.fleet, "_user_defined_strategy"): + strategy = fleet.fleet._user_defined_strategy + sharding_configs = strategy.hybrid_configs["sharding_configs"] + pp_config = strategy.hybrid_configs["pp_configs"] + assert ( + not sharding_configs.comm_overlap + and not pp_config.sharding_comm_overlap + ), "orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" + + self.eps = paddle.to_tensor([1e-12], dtype="float32") + if config.multimodel_experts: + if config.moe_use_hard_gate: + self.num_experts_list = [] + self.experts_type_mask = [] + experts_ids = paddle.zeros( + [sum(self.num_experts)], dtype="int64" + ).reshape([config.moe_world_size, -1]) + offset = 0 + for i, expert_num in enumerate(self.num_experts): + experts_ids[ + :, offset : offset + expert_num // config.moe_world_size + ] = i + offset += expert_num // config.moe_world_size + self.experts_type_ids = experts_ids.reshape([-1]) + logger.info( + f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" + ) + for i, expert_num in enumerate(self.num_experts): + self.experts_type_mask.append( + self.experts_type_ids == i, + ) + self.num_experts_list.append(expert_num) + else: + assert ( + not config.moe_group_experts + ), "group_experts must use hard_gate when multimodel_experts is True" + else: + self.num_experts_list = 
[self.num_experts] + if gate_weight is not None: + self.weight = gate_weight + assert ( + not self.config.moe_use_token_type_bias + ), "gate_weights is from outside, token_type_bias can't be used" + logger.info("moe use gate_weight from outside") + self._cast_to_low_precision = False + self._cast_to_low_precison = False + else: + self._create_gate_parameter() + logger.info( + f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " + f"use_token_type_bias:{self.use_token_type_bias} gate_act:{config.moe_gate_act} " + f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}" + ) + + def _create_gate_parameter(self): + + if self.config.multimodel_experts: + self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand( + len(self.num_experts) + ) + + for i, num_experts in enumerate(self.num_experts): + if i == 1: + with paddle.utils.unique_name.guard(f"mm_gate_{self.layer_idx}_"): + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), + ) + else: + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + p.expert_type = f"expert_type_{i}" + self.add_parameter( + ("weight" if i == 0 else f"weight_{i}"), + p, + ) + else: + self.weight = self.create_parameter( + shape=[self.model_dim, self.num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + logger.info(f"moe-Gate, {self.weight}") + + if self.use_token_type_bias: + if self.config.multimodel_experts: + assert ( + not self.config.moe_use_hard_gate + ), "multimodel_experts with hard_gate is not support token_type_bias." 
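+            # One bias row per token modality (len(self.num_experts) rows when
+            # multimodel_experts is on, otherwise a single row); TopKGateFused.forward
+            # adds self.bias[token_type_ids] to the gate logits so routing can
+            # differ per token type.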
+ num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + bias_type_num = ( + len(self.num_experts) if self.config.multimodel_experts else 1 + ) + self.bias = self.create_parameter( + shape=[bias_type_num, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate_bias"), + initializer=paddle.nn.initializer.Assign( + np.zeros([bias_type_num, num_experts]) + ), + ), + ) + logger.info(f"using token type bias, bias: {self.bias},") + self._cast_to_low_precision = False + self._cast_to_low_precison = False + + def get_gate_weight(self, transform_weight): + if not self.config.multimodel_experts: + return self.weight + if not transform_weight: + return paddle.concat( + [ + getattr(self, "weight" if i == 0 else f"weight_{i}") + for i in range(len(self.num_experts)) + ], + -1, + ) + weight = paddle.zeros( + [ + self.model_dim, + self.config.moe_world_size, + sum(self.num_experts) // self.config.moe_world_size, + ], + dtype="float32", + ) + offset = 0 + for i, num_experts in enumerate(self.num_experts): + weight[ + :, :, offset : offset + num_experts // self.config.moe_world_size + ] = getattr(self, "weight" if i == 0 else f"weight_{i}").reshape( + [self.model_dim, self.config.moe_world_size, -1] + ) + offset += num_experts // self.config.moe_world_size + weight = weight.reshape([self.model_dim, -1]) + + return weight + + def forward( + self, + input: Tensor, + token_type_ids: Tensor = None, + transform_weight: bool = True, + correction_bias: Tensor = None, + ): + pass + + def get_capacity(self, num_tokens, cap_factor=None): + + num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + if cap_factor is not None: + cap = cap_factor + else: + if self.training: + cap = self.cap[0] + elif num_tokens < num_experts: + cap = self.cap[2] + else: + cap = self.cap[1] + capacity = int(cap * num_tokens // num_experts) + assert ( + capacity > 0 + ), f"requires capacity to >= 0. 
cap={cap}, num_tokens={num_tokens}" + return capacity + + def top2_gating(self, logits, cap=None, correction_bias=None): + + l_zloss = self._cal_z_loss(logits) + gates = self.act(logits) + + assert logits.ndim == 2, logits.shape + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + capacity = self.get_capacity(logits.shape[0], cap) + + score_for_argmax = ( + gates + correction_bias.unsqueeze(0) + if correction_bias is not None + else gates + ) + indices1_s = paddle.argmax(score_for_argmax, axis=1) + mask1 = F.one_hot(indices1_s, num_classes=num_experts).cast(paddle.int64) + + l_aux = self._cal_aux_loss(gates, mask1.sum(axis=0), self.num_experts_tensor) + + if self.training and not self.no_jitter: + gumbels = ( + -paddle.empty_like( + logits, + ) + .exponential_() + .log() + ) + logits_w_noise = logits + gumbels + else: + logits_w_noise = logits + + logits_except1 = masked_fill( + logits_w_noise, mask1.cast(paddle.bool), float("-inf") + ) + score_for_argmax = ( + self.act(logits_except1) + correction_bias.unsqueeze(0) + if correction_bias is not None + else logits_except1 + ) + indices2_s_original = paddle.argmax(score_for_argmax, axis=1) + + if self.training and self.sinkhorn_2gate: + r = paddle.ones(num_tokens, "float32") / num_tokens + + c = capacity - mask1.cast("float32").sum(0) + c = paddle.maximum(c, paddle.zeros_like(c)) + c /= c.sum() + + pi, _ = compute_optimal_transport( + -logits_except1.cast("float32").detach(), r, c, lam=self.sinkhorn_temp + ) + pi = masked_fill(pi, mask1.cast(paddle.bool), float("-inf")) + indices2_s = paddle.argmax(pi, axis=1) + else: + indices2_s = indices2_s_original + + mask2 = F.one_hot(indices2_s, num_classes=self.num_experts).cast(paddle.int64) + + locations1 = paddle.cumsum(mask1, axis=0) - 1 + locations2 = paddle.cumsum(mask2, axis=0) - 1 + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + mask1 *= (locations1 < capacity).cast(paddle.int64) + mask2 *= (locations2 < capacity).cast(paddle.int64) + + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = (gates * mask1_float).sum(axis=-1) + gates2_s = (gates * mask2_float).sum(axis=-1) + + if self.norm_gate_logits: + denom_s = gates1_s + gates2_s + denom_s = paddle.clip(denom_s, min=1e-6) + gates1_s /= denom_s + gates2_s /= denom_s + if self.training and self.expert_drop: + gates2_s = paddle.where( + 2 * gates2_s < paddle.rand_like(gates2_s), + paddle.zeros_like(gates2_s), + gates2_s, + ) + + gates1 = gates1_s.unsqueeze(1) * mask1_float + gates2 = gates2_s.unsqueeze(1) * mask2_float + + expert1_index = paddle.argmax(gates1, -1) + combine1_weight = paddle.max(gates1, -1, keepdim=True) + scatter1_index = expert1_index * capacity + locations1_s + scatter1_index = scatter1_index.cast("int64") + dispatch1_mask = combine1_weight.cast(paddle.bool).detach() + + expert2_index = paddle.argmax(gates2, -1) + combine2_weight = paddle.max(gates2, -1, keepdim=True) + scatter2_index = expert2_index * capacity + locations2_s + scatter2_index = scatter2_index.cast("int64") + dispatch2_mask = combine2_weight.cast(paddle.bool).detach() + + return ( + capacity, + paddle.concat((dispatch1_mask, dispatch2_mask), 1), + paddle.concat((combine1_weight, combine2_weight), 1), + paddle.stack((scatter1_index, scatter2_index), 1), + l_aux, + l_zloss, + ) + + def _cal_aux_loss( + self, + gate_prob, + dispatch_mask, + num_experts=None, + use_group=None, + 
tokens_mask=None, + dispatch_tokens_mask=None, + ): + + if self.act is F.sigmoid: + gate_prob = gate_prob / gate_prob.sum(-1, keepdim=True) + + if self.use_correction_bias: + if tokens_mask is not None: + gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] + if gate_prob_this_modality.shape[0]: + _, top_idx = gate_prob_this_modality.topk( + k=self.config.moe_k, axis=-1 + ) + mask = paddle.zeros_like( + gate_prob_this_modality + ).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + else: + dispatch_mask = paddle.zeros(gate_prob.shape[-1], dtype="int64") + dist.stream.all_reduce( + dispatch_mask, + group=self.group, + use_calc_stream=True, + ) + else: + _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) + + mask = paddle.zeros_like(gate_prob).put_along_axis( + top_idx, paddle.to_tensor(1.0), axis=1 + ) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + + if num_experts is None: + num_experts = self.num_experts_tensor + if use_group is None: + use_group = self.config.moe_group_experts + + return cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + self.global_aux_loss, + self.rank if self.global_aux_loss else None, + self.group if self.global_aux_loss else None, + ) + + +class TopKGateFused(Top2Gate): + + def forward( + self, + input: Tensor, + token_type_ids=None, + transform_weight=True, + ) -> Tuple[Tensor, Tensor, Tensor]: + + capacity = self.get_capacity(input.shape[0]) + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + + logits = gate_detach_matmul( + input, weight, self.fuse_gate_detach_matmul, self.use_fake_gate + ) + if self.use_token_type_bias: + assert token_type_ids is not None + assert ( + token_type_ids.max() < self.bias.shape[0] + ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" + bias = self.bias[token_type_ids] + logits = logits + bias + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + + return logits, capacity, router_loss + + +class TopKGateFusedAuto(TopKGateFused): + """doc""" + + def __init__(self, config, layer_idx: int, group, gate_weight=None, ipp=0) -> None: + super().__init__(config, layer_idx, group, gate_weight) + self.ipp = ipp + self.weight = dist.shard_tensor( + self.weight, get_flatten_mesh(get_mesh(self.ipp)), [dist.Replicate()] + ) + + def forward( + self, + input: Tensor, + token_type_ids=None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + Args: + input: paddle.Tensor, hidden-states of layer + Retruns: + paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights + paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask + Tuple[paddle.Tensor]: `GateOutput` + """ + num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + if self.training: + cap = self.cap[0] + elif input.shape[0] < num_experts: + cap = self.cap[2] + else: + cap = self.cap[1] + num_tokens = input.shape[0] + global_capacity = int(cap * num_tokens // num_experts) + local_num_tokens = input._local_shape[0] + local_capacity = int(cap * local_num_tokens // num_experts) + + logits, _, router_loss = super().forward(input, token_type_ids) + + return logits, global_capacity, router_loss, local_capacity diff --git a/examples/pre-training/models/sequence_parallel_utils_auto.py b/examples/pre-training/models/sequence_parallel_utils_auto.py new file mode 
100644 index 00000000..35f11077 --- /dev/null +++ b/examples/pre-training/models/sequence_parallel_utils_auto.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# !/usr/bin/env python3 + +import numpy as np + +import paddle +from paddle import distributed as dist +from paddle.autograd import PyLayer +from paddle.distributed import fleet + + +def scatter(input, group=None, axis=0): + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + rank = group.rank + seq_len = input.shape[axis] + assert seq_len % parallelism == 0, ( + f"Input sequence length {seq_len} can't be divided exactly" + f" by sequence parallelism {parallelism}" + ) + interval = seq_len // parallelism + input = paddle.slice( + input, axes=[axis], starts=[interval * rank], ends=[interval * (rank + 1)] + ) + input = paddle.assign(input) + return input + + +def all_gather(input, group=None, axis=0): + if group is None: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + parallelism = group.nranks + if parallelism == 1: + return input.clone() + output_shape = input.shape + if axis == 0: + output_shape[axis] = output_shape[axis] * parallelism + output = paddle.empty(shape=output_shape, dtype=input.dtype) + dist.stream.all_gather(output, input, group=group, use_calc_stream=True) + return output + outputs = [ + paddle.empty(output_shape, dtype=input.dtype) for _ in range(parallelism) + ] + dist.stream.all_gather(outputs, input, group=group, use_calc_stream=True) + output = paddle.concat(outputs, axis=axis) + return output + + +class ScatterOp(PyLayer): + + @staticmethod + def forward(ctx, input, axis=0, group=None): + ctx.axis = axis + ctx.group = group + return scatter(input, axis=axis, group=ctx.group) + + @staticmethod + def backward(ctx, grad): + return all_gather(grad, axis=ctx.axis, group=ctx.group) + + +class AllGatherVarlenOp(PyLayer): + + @staticmethod + def forward(ctx, input, group=None): + hcg = fleet.get_hybrid_communicate_group() + if group is None: + group = hcg.get_model_parallel_group() + + shape0 = paddle.to_tensor([input.shape[0]]) + shape0_all = paddle.empty(shape=[group.nranks], dtype=shape0.dtype) + dist.stream.all_gather(shape0_all, shape0, group=group, use_calc_stream=True) + shape0_all = shape0_all.numpy() + max_shape0 = shape0_all.max() + + indices = [] + for idx, s in enumerate(shape0_all): + offset = idx * max_shape0 + indices.append(list(range(offset, offset + s))) + indices = np.concatenate(indices, axis=0) + indices = indices.reshape([-1] + [1] * (len(input.shape) - 1)) + indices = paddle.to_tensor(indices, dtype=paddle.int32) + + padding = max_shape0 - input.shape[0] + + ctx.shape0 = input.shape[0] + ctx.max_shape0 = max_shape0 + ctx.shape0_all = shape0_all + ctx.padding = padding + ctx.indices = indices + ctx.group = group + + if padding > 0: + input_shape = 
input.shape + input_shape[0] = padding + padding_tensor = paddle.empty(shape=input_shape, dtype=input.dtype) + input = paddle.concat([input, padding_tensor], axis=0) + output = all_gather(input, group) + output = paddle.take_along_axis(output, indices, axis=0) + + return output + + @staticmethod + def backward(ctx, grad): + input_shape = grad.shape + input_shape[0] = ctx.max_shape0 * ctx.shape0_all.shape[0] + output = paddle.zeros(shape=input_shape, dtype=grad.dtype) + + grad = paddle.scatter(output, ctx.indices, grad) + + grad = scatter(grad, ctx.group) + + if ctx.padding > 0: + grad = grad[: ctx.shape0] + return grad + + +def sequence_parallel_sparse_mask_labels(labels, ignore_label=-100): + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + labels = labels.flatten() + labels_local = paddle.split(labels, group.nranks)[group.rank] + + tgt_index = paddle.nonzero(labels_local != ignore_label).squeeze() + if tgt_index.numel() == 0: + tgt_index = paddle.to_tensor([0]) + + tgt_index = tgt_index.reshape([-1]).astype(paddle.int32) + labels_local_gather = paddle.take_along_axis(labels_local, tgt_index, axis=0) + labels_all_gather = AllGatherVarlenOp.apply(labels_local_gather) + return labels_all_gather, tgt_index.reshape([-1, 1]) diff --git a/examples/pre-training/models/utils_auto.py b/examples/pre-training/models/utils_auto.py new file mode 100644 index 00000000..364323e7 --- /dev/null +++ b/examples/pre-training/models/utils_auto.py @@ -0,0 +1,93 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
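AllGatherVarlenOp above handles variable-length inputs by padding every rank's rows to the longest one, doing a regular all_gather, and then keeping only the real rows via an index built from the gathered lengths. A small NumPy sketch of that index bookkeeping, assuming three ranks that contribute 3, 5 and 2 rows (rank count and lengths are made up for illustration):

import numpy as np

shape0_all = np.array([3, 5, 2])      # rows contributed by each rank
max_shape0 = shape0_all.max()         # every rank pads its input to 5 rows

indices = []
for idx, s in enumerate(shape0_all):
    offset = idx * max_shape0
    indices.append(list(range(offset, offset + s)))
indices = np.concatenate(indices, axis=0)
# indices == [0 1 2 5 6 7 8 9 10 11]: positions of the real rows inside the
# padded 15-row all_gather result, later selected with paddle.take_along_axis.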
+ +import logging +from typing import Any, Callable, List + +import paddle +from paddle import framework + +logger = logging.getLogger(__name__) + +try: + import moe_permutation + +except ImportError: + moe_permutation = None + logger.warning("moe_permutation is not installed.") + + +def detach_and_requires_grad_(*args): + ret = [a.detach() if a is not None else None for a in args] + for r, a in zip(ret, args): + if a is not None: + r.stop_gradient = a.stop_gradient + return ret + + +class FakeClone(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, input): + if input.is_contiguous(): + fake_output = paddle.empty_like(input) + input._share_buffer_to(fake_output) + else: + fake_output = input.clone() + return fake_output + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +def manual_backward(f: Callable, is_first_fwd: bool, *args: List[Any]): + tracer = framework._dygraph_tracer() + orig = tracer._has_grad + if not is_first_fwd: + tracer._has_grad = True + + detached_args = detach_and_requires_grad_(*args) + detached_args_clone = [ + FakeClone.apply(a) if a is not None else None for a in detached_args + ] + out = f(*detached_args_clone) + if isinstance(out, list): + out = tuple(out) + elif not isinstance(out, tuple): + out = (out,) + + if is_first_fwd: + tracer._has_grad = orig + return None, out + + out_cached = [FakeClone.apply(o) for o in out if o is not None] + + for o in out_cached: + o._clear_dataptr() + tracer._has_grad = orig + + def bwd_f(*grad): + nonlocal out_cached, detached_args, f + grad = list(grad) + grad = [g for g in grad if g is not None] + assert grad and out_cached, (len(grad), len(out_cached)) + grad, out_cached = zip( + *[(g, o) for g, o in zip(grad, out_cached) if not o.stop_gradient] + ) + + assert len(grad) == len(out_cached), (len(grad), len(out_cached), f) + paddle.autograd.backward(out_cached, grad) + return tuple([t.grad for t in detached_args if t is not None]) + + return bwd_f, out diff --git a/examples/pre-training/scripts/train_96_auto.sh b/examples/pre-training/scripts/train_96_auto.sh new file mode 100644 index 00000000..25f139c8 --- /dev/null +++ b/examples/pre-training/scripts/train_96_auto.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
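manual_backward in models/utils_auto.py above runs a function on detached copies of its inputs; on the first forward it only returns the outputs, while on the recompute pass it additionally returns a closure that turns output gradients into gradients of those detached inputs (the cached outputs have their storage released via the private Tensor._clear_dataptr, so only the graph is kept). A single-process usage sketch, assuming eager mode and a Paddle build that provides that private API; this is illustrative only, in this codebase the helper is presumably driven by the pipeline/recompute machinery rather than called like this:

import paddle
from models.utils_auto import manual_backward

def f(x):
    return x * 2.0 + 1.0

x = paddle.randn([4, 8])
x.stop_gradient = False

# Forward-only pass: no backward closure is kept.
_, (y,) = manual_backward(f, True, x)

# Recompute pass: also get a closure mapping d(out) -> d(detached inputs).
bwd_f, (y,) = manual_backward(f, False, x)
(x_grad,) = bwd_f(paddle.ones_like(y))   # == 2 * ones, the grad of 2*x + 1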
+ +export CUDA_MODULE_LOADING=LAZY +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTHONUNBUFFERED=1 +unset GLOG_vmodule GLOG_v +export PADDLE_DISABLE_CUDNN_FA=1 +export FLAGS_use_auto_growth_pinned_allocator=True +export FLAGS_pipeline_nccl_comm_init_option=1 +export FLAGS_sharding_v2_check_zero_padding=1 +export FLAGS_use_paddle_recall_error=0 +export FLAGS_tcp_max_syn_backlog=16384 +export FLAGS_call_stack_level=2 + + +SM=`nvidia-smi --query-gpu=compute_cap --format=csv | tail -n 1 | sed 's/\.//g'` +if [ $SM -eq 90 ] +then + export FLAGS_flash_attn_version=3 +else + export FLAGS_flash_attn_version=2 +fi + + +export FLAGS_enable_fused_ffn_qkv_pass=1 +export FLAGS_enable_pir_api=1 +export FLAGS_enable_moe_utils=true +export FLAGS_call_stack_level=2 + + +export PYTHONPATH=$PYTHONPATH:./ernie + +python -m paddle.distributed.launch \ + --log_dir output/paddle_distributed_logs \ + --run_mode=collective \ + ${script:-ernie/pretrain_auto.py} \ + --config yamls/pretrain_96_auto.yaml diff --git a/examples/pre-training/yamls/pretrain_96_auto.yaml b/examples/pre-training/yamls/pretrain_96_auto.yaml new file mode 100644 index 00000000..39fbef49 --- /dev/null +++ b/examples/pre-training/yamls/pretrain_96_auto.yaml @@ -0,0 +1,91 @@ +env: + HOME: null + +# ---------------------------model args-------------------------------------------------# +model_args: + model_name_or_path: model_configs_auto/ + tokenizer_name: ./ernie/src/tokenizers/tokenizer_model + output_dir: ./output/ + max_seq_length: 4096 + base_seq_length: 4096 + num_consecutive: 32 + sequence_parallel: 1 + enable_global_training_logs: False + moe_use_aux_free_update_coef: 0.001 + global_logging_interval: 10 + model_config: + moe_logging: True + moe_use_aux_free: true + multi_token_pred_depth: 0 + +# ---------------------------trainer args-------------------------------------------------# +trainer_args: + input_dir: "0.4 ./demo_data/data-1-part0 0.6 ./demo_data/data-1-part0" + split: "998,1,1" + use_sp_callback: true + moe_gate_lr_ratio: 0.01 + do_train: True + dataloader_num_workers: 8 + prefetch_factor: 32 + overwrite_output_dir: 1 + disable_tqdm: 1 + logging_steps: 1 + eval_steps: 1000 + eval_iters: -1 + save_steps: 3000 + max_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_epsilon: 1e-8 + learning_rate: 2.2e-4 + min_lr: 2.2e-5 + global_batch_size: 2 + gradient_accumulation_steps: 1 + per_device_train_batch_size: 2 + per_device_eval_batch_size: 1 + lr_scheduler: wsd:231084 + decay_function: 1-sqrt + max_grad_norm: 1.0 + use_async_save: True + weight_decay: 0.1 + warmup_steps: 200 + save_total_limit: 5 + bf16: True + fp16_opt_level: "O2" + scale_loss: 4096 + seed: 666 + pre_alloc_memory: 60 + + tensor_parallel_degree: 4 # N7:8, N4:8, N1:4 + pipeline_parallel_degree: 2 # N7:7, N4:4, N1:2 + virtual_pp_degree: 1 # N7:8, N4:8, N1:1 + + n_microbatches: 2 + pipeline_schedule_mode: "VPP" + model_type: "ernie_pp" + + data_parallel_degree: 1 + sharding: "stage1" + sharding_degree: 1 + amp_master_grad: 1 + pipeline_parallel_config: enable_delay_scale_loss + sharding_parallel_config: split_param enable_fuse_optimizer_states + sharding_comm_buffer_size_MB: 2048 + tensor_parallel_config: replace_with_parallel_cross_entropy + + skip_profile_timer: False + ignore_data_skip: 0 + shuffle_consecutive: True + load_sharded_model: True + save_sharded_model: True + ignore_load_lr_and_optim: False + metrics_output_path: ./output/paddle_distributed_logs/ + + use_moe: true + moe_group: mp + log_global_grad_norm: True + enable_optimizer_timer: False + 
gc_interval: 100000 + + enable_auto_parallel: 1 + to_static: 0
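For reference, the parallel degrees in pretrain_96_auto.yaml multiply out as follows for the single-node (N1) setting; the batch-size bookkeeping below follows the usual Paddle hybrid-parallel convention and is an assumption about this config, not something the diff states:

tp, pp, dp, sharding = 4, 2, 1, 1
gpus_needed = tp * pp * dp * sharding                 # 8 cards: tp4 x pp2
per_device_bs, grad_acc = 2, 1
global_bs = per_device_bs * dp * sharding * grad_acc  # 2, matching global_batch_size: 2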