374 changes: 374 additions & 0 deletions examples/pre-training/ernie/pretrain_auto.py
@@ -0,0 +1,374 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import json
import numpy as np
import random
import paddle
import paddle.distributed.fleet as fleet
from src.utils_auto import logger
from paddleformers.trainer import (
    PdArgumentParser,
    get_last_checkpoint,
)
from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer
from omegaconf.listconfig import ListConfig
from omegaconf.dictconfig import DictConfig
from src.callbacks_auto import (
    GlobalRNGCallback,
)
from models.ernie import (
    ErnieForCausalLMAuto,
    ErnieForCausalLMAutoPP,
)
from models.ernie.configuration_auto import (
    ErnieConfig,
    ErnieMoEConfig,
)
from src.trainers import AutoPretrainingTrainer, AutoPreTrainingArguments
from src.utils_auto import (
    setup_logger_output_file,
)
from src.utils_auto.misc import global_training_logs

from paddleformers.data.causal_dataset import (
    build_train_valid_test_datasets,
    check_data_split,
)


from config import get_config

try:
    from paddleformers.trainer.trainer_utils import log_trainer_start
except ImportError:

    def log_trainer_start():
        """Print the main-process startup message."""
        if "MAIN_PROCESS_STARTED" not in os.environ:
            start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            logger.info(
                f"The Training Main Process Started Successfully. time: {start_time}, pid: {os.getpid()}"
            )
            os.environ["MAIN_PROCESS_STARTED"] = "1"


from paddle.distributed.fleet import collective_perf

log_trainer_start()


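# Build mmap-backed train/valid/test datasets. Sample counts per split are
# derived from the training schedule: train consumes batch_size *
# dataset_world_size * max_steps * gradient_accumulation_steps samples, while
# the eval/test splits are sized from eval_iters/test_iters.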
def create_pretrained_dataset(args):
    assert args.input_dir is not None and len(args.input_dir.split()) > 1

    check_data_split(
        args.split,
        args.do_train,
        args.do_eval,
        args.do_predict,
    )

    train_val_test_num_samples = [
        args.per_device_train_batch_size
        * args.dataset_world_size
        * args.max_steps
        * args.gradient_accumulation_steps,
        args.per_device_eval_batch_size
        * args.dataset_world_size
        * args.eval_iters
        * (args.max_steps // args.eval_steps + 1),
        args.per_device_eval_batch_size * args.dataset_world_size * args.test_iters,
    ]

    train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets(
        data_prefix=args.input_dir.split(),
        data_impl="mmap",
        splits_string=args.split,
        train_val_test_num_samples=train_val_test_num_samples,
        seq_length=args.max_seq_length + args.multi_token_pred_depth,
        seed=args.seed,
        skip_warmup=True,
        data_cache_path=None,
    )

    from paddleformers.data import Stack

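    # Collate: stack the samples' token ids into one [batch, seq_len] array and
    # shift by one position, so `labels` hold the next-token targets for
    # `input_ids`.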
    def _collate_data(data, stack_fn=Stack()):
        tokens_ = stack_fn([x["text"] for x in data])

        labels = tokens_[:, 1:]
        tokens = tokens_[:, :-1]

        return {
            "input_ids": tokens,
            "labels": labels,
        }

    return train_dataset, valid_dataset, test_dataset, _collate_data


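# Overlay entries from the YAML `model_config` section onto the loaded config,
# logging each override and warning about keys the config does not define.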
def update_model_config_from_args(config: ErnieConfig, model_args: dict):
    for k, v in model_args.items():
        if hasattr(config, k):
            logger.info(f"update model config: {k} = {v}")
            setattr(config, k, v)
        else:
            logger.warning(f"model config key: {k} does not exist")
    return config


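# The model is built under paddle.LazyGuard (see main), so its parameters are
# created without being initialized; initialize() materializes each of them.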
def init_parameter(model):
    for param in model.parameters():
        param.initialize()


def main():
    """Main function"""
    config = get_config(verbose=True)
    os.makedirs(config.model_args.output_dir, exist_ok=True)
    parser = PdArgumentParser(AutoPreTrainingArguments)
    if not hasattr(config.trainer_args, "pipeline_parallel_config"):
        config.trainer_args.pipeline_parallel_config = ""

    if "enable_dp_comm_overlap" in config.trainer_args.pipeline_parallel_config:
        logger.warning(
            "Pipeline dp_comm_overlap and FusedLinearWithGradAdd cannot be used at "
            "the same time."
        )

    if "enable_timer" in config.trainer_args.pipeline_parallel_config:
        from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
            PipelineParallel,
        )

        # Silence the per-stage pipeline timer output.
        PipelineParallel.timer_printer = lambda _: None

    def formatv(v):
        if isinstance(v, ListConfig):
            return list(v)
        elif isinstance(v, DictConfig):
            return dict(v)
        return v

    model_args = {k: formatv(v) for k, v in dict(config.model_args).items()}
    trainer_args = {k: formatv(v) for k, v in dict(config.trainer_args).items()}
    (args,) = parser.parse_dict(dict(**model_args, **trainer_args))

    if args.strategy.pipeline.enable and args.virtual_pp_degree > 1:
        pipeline = args.strategy.pipeline
        pipeline.vpp_degree = args.virtual_pp_degree
        pipeline.vpp_seg_method = args.virtual_pipeline_seg_method

    args.eval_iters = 10
    args.test_iters = args.eval_iters * 10

    args.use_moe = dict(**dict(config.model_args), **dict(config.trainer_args)).get(
        "use_moe", False
    )
    model_config = dict(getattr(config.model_args, "model_config", {}))
    model_config = {k: formatv(v) for k, v in model_config.items()}
    logger.info(f"model_config_from_yaml: {json.dumps(model_config, indent=4)}")
    setup_logger_output_file(config.model_args.output_dir, args.local_rank)
    paddle.set_device(args.device)
    np.random.seed(args.seed)
    random.seed(args.seed)
    paddle.seed(args.seed)

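    # Optional memory warm-up: allocate and immediately free one large uint8
    # tensor so the GPU memory pool is grown up front rather than mid-training.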
    prop = paddle.device.cuda.get_device_properties()
    if prop.total_memory < args.pre_alloc_memory * 1024 * 1024 * 1024:
        logger.warning(
            "`pre_alloc_memory` exceeds total GPU memory; skipping pre-allocation."
        )
    elif args.pre_alloc_memory > 0:
        logger.warning(
            f"pre-allocating a tensor of {args.pre_alloc_memory} GB "
            "and then releasing it."
        )
        memory_size = int(args.pre_alloc_memory * 1024 * 1024 * 1024)
        x = paddle.empty([memory_size], dtype=paddle.uint8)
        del x

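    # Run timed allgather/allreduce probes at several buffer sizes (bytes ->
    # expected seconds) to sanity-check collective communication; a failing
    # probe only logs a warning and never aborts the run.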
    try:
        collective_perf(
            "allgather",
            round=50,
            size_and_time={67108864: 0.00625, 234881024: 0.02, 637534208: 0.057},
        )
        logger.info("======monitor allgather done!=======\n")
        collective_perf(
            "allreduce",
            round=50,
            size_and_time={67108864: 0.02, 134217728: 0.038, 268435456: 0.075},
        )
        logger.info("======monitor allreduce done!=======\n")
    except Exception as e:
        logger.warning(f"fleet perf test failed with unexpected error, skipping: [{e}]")

    # Detecting last checkpoint.
    last_checkpoint = None
    if (
        os.path.isdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(args.output_dir)
        if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

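    # Define the evaluation metrics: `predictions` carry summed NLL values and
    # `label_ids` the token labels; dividing by the count of non-ignored tokens
    # gives a corpus-level nll_loss, and ppl = exp(nll_loss).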
    def compute_metrics(p):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions

        output = paddle.to_tensor(preds)
        labels = paddle.to_tensor(p.label_ids)
        output = [t.astype("float32").cuda() for t in output]
        labels = [t[t != tokenizer.ignored_index] for t in labels]
        labels = [t.cuda() for t in labels]
        all_numel = (
            (paddle.concat(labels, 0) != tokenizer.ignored_index).astype("int64").sum()
        )
        ignored = (paddle.concat(labels, 0) == -100).astype("int64").sum()
        labels = all_numel - ignored
        output = sum(output)
        logger.info(f"output : {output.item()}, labels : {labels.item()}")
        nll_loss = output / (labels + 1.0e-6)
        ppl = paddle.exp(nll_loss)

        return {
            "nll_loss": nll_loss.item(),
            "ppl": ppl.item(),
            "num_token": labels.item(),
        }

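    # Model construction. Pick the default dtype before building: fp16-O2 and
    # bf16 runs build weights directly in the low-precision dtype; the default
    # is restored to float32 once the model has been constructed.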
    dtype = "float32"
    if args.fp16 and args.fp16_opt_level == "O2":
        paddle.set_default_dtype("float16")
        dtype = "float16"
    elif args.bf16:
        paddle.set_default_dtype("bfloat16")
        dtype = "bfloat16"

    if args.use_moe:
        global ErnieConfig, ErnieForCausalLMAuto
        ErnieConfig = ErnieMoEConfig

        if args.moe_group.lower() in {"mp", "tp", "model", "dummy"}:
            logger.info(f"disable moe flag when using moe-group={args.moe_group}")
            args.use_moe = False
    args.multi_token_pred_depth = model_config.get("multi_token_pred_depth", 0)

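    # Load the pretrained config, then stamp runtime settings onto it: sequence
    # lengths, dtype, and the pipeline/tensor-parallel layout taken from fleet.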
    cfg = ErnieConfig.from_pretrained(args.model_name_or_path)
    cfg.seqlen = args.max_seq_length
    cfg.token_balance_seqlen = args.max_seq_length * args.per_device_train_batch_size
    cfg.fp16_opt_level = args.fp16_opt_level
    cfg.moe_group = args.moe_group
    cfg.dtype = dtype
    cfg.pipeline_parallel_degree = args.pipeline_parallel_degree
    cfg.virtual_pp_degree = args.virtual_pp_degree
    if args.tensor_parallel_degree > 1:
        cfg.sequence_parallel = args.sequence_parallel
        cfg.tensor_parallel_degree = max(
            fleet.get_hybrid_communicate_group().get_model_parallel_world_size(), 1
        )
        cfg.tensor_parallel_rank = max(
            fleet.get_hybrid_communicate_group().get_model_parallel_rank(), 0
        )
    else:
        cfg.sequence_parallel = False
        cfg.tensor_parallel_degree = 1
        cfg.tensor_parallel_rank = 0

    cfg.micro_batch_size = args.per_device_train_batch_size
    tokenizer = ErnieBotTokenizer.from_pretrained(args.tokenizer_name)
    tokenizer.ignored_index = cfg.ignored_index
    logger.info(
        f"using tokenizer={type(tokenizer)}, bos:{tokenizer.bos_token_id} "
        f"eos:{tokenizer.eos_token_id} pad:{tokenizer.pad_token_id} "
    )

    cfg = update_model_config_from_args(cfg, model_config)

    if args.model_type == "ernie":
        model_class = ErnieForCausalLMAuto
    elif args.model_type == "ernie_pp":
        model_class = ErnieForCausalLMAutoPP
    else:
        raise ValueError(f"unsupported model_type: {args.model_type}")

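    # LazyGuard defers parameter allocation, so even a very large model can be
    # constructed here cheaply; init_parameter() materializes the weights later.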
    with paddle.LazyGuard():
        model = model_class(cfg)

    cfg = model.config
    logger.info(f"using model type:{type(model)}")
    paddle.set_default_dtype("float32")

    logger.info(f"using model={type(model)}, cfg={cfg}")

    # data
    logger.info("loading data...")
    train_dataset, eval_dataset, test_dataset, data_collator = (
        create_pretrained_dataset(args)
    )

    callbacks = [GlobalRNGCallback()]

    init_parameter(model)
    model.apply(model.init_weights)
    trainer = AutoPretrainingTrainer(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
    )
    global_training_logs.accumulate = args.gradient_accumulation_steps
    checkpoint = None
    if args.resume_from_checkpoint is not None:
        checkpoint = args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model(args.output_dir)
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluate the model
    if args.do_eval:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)


if __name__ == "__main__":
    main()