diff --git a/llm/run_pretrain_llm.sh b/llm/run_pretrain_llm.sh
new file mode 100644
index 000000000000..f9755a9ce7b0
--- /dev/null
+++ b/llm/run_pretrain_llm.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+
+# Set environment variables
+export PYTHONPATH=../:$PYTHONPATH
+export FLAGS_call_stack_level=3
+export NVIDIA_TF32_OVERRIDE=0
+export FLAGS_cudnn_deterministic=True
+export FLAGS_embedding_deterministic=1
+
+# Set output directories
+task_name="llama_pretrain"
+case_out_dir="output/${task_name}"
+case_log_dir="output/${task_name}_log"
+
+# Clean up old output directories
+rm -rf $case_out_dir
+rm -rf $case_log_dir
+
+# Launch training
+python -u -m paddle.distributed.launch \
+    --gpus "0,1,2,3" \
+    --log_dir "$case_log_dir" \
+    run_pretrain.py \
+    --model_name_or_path "meta-llama/Llama-2-7b" \
+    --tokenizer_name_or_path "meta-llama/Llama-2-7b" \
+    --input_dir "./data" \
+    --split "949,50,1" \
+    --num_hidden_layers 4 \
+    --output_dir "$case_out_dir" \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --per_device_eval_batch_size 8 \
+    --tensor_parallel_degree 4 \
+    --pipeline_parallel_degree 1 \
+    --tensor_parallel_config "enable_delay_scale_loss enable_mp_async_allreduce enable_mp_skip_c_identity" \
+    --pipeline_parallel_config "enable_delay_scale_loss enable_release_grads disable_partial_send_recv enable_overlap_p2p_comm" \
+    --virtual_pp_degree 1 \
+    --sequence_parallel 0 \
+    --use_flash_attention 0 \
+    --use_fused_rms_norm 0 \
+    --enable_linear_fused_grad_add 0 \
+    --learning_rate 3e-05 \
+    --logging_steps 1 \
+    --max_steps 10 \
+    --save_steps 11 \
+    --eval_steps 1000 \
+    --weight_decay 0.01 \
+    --fp16 1 \
+    --fp16_opt_level "O2" \
+    --amp_master_grad 1 \
+    --max_grad_norm 1.0 \
+    --dataloader_num_workers 1 \
+    --continue_training 0 \
+    --do_train true \
+    --do_eval false \
+    --do_predict false \
+    --disable_tqdm true \
+    --skip_profile_timer true \
+    --recompute 0 \
+    --save_total_limit 2 \
+    --device "gpu" \
+    --save_sharded_model 0 \
+    --unified_checkpoint 0 \
+    --using_flex_checkpoint 1 \
+    --fuse_attention_qkv true \
+    --fuse_attention_ffn true \
+    # --resume_from_checkpoint "./output/llama_pretrain/checkpoint-1"
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index fdce86316878..2c2483cbe161 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -159,6 +159,7 @@
     get_last_checkpoint,
     get_scheduler,
     has_length,
+    init_optimizer,
     set_seed,
     should_skip_data,
     speed_metrics,
@@ -197,7 +198,6 @@
 if is_datasets_available():
     import datasets
 
-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:
@@ -914,7 +914,7 @@ def train(
         self._memory_tracker.start()
 
         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)
 
             if self.args.should_load_sharding_stage1_model:
@@ -934,7 +934,7 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
+            elif not self.args.using_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
@@ -942,6 +942,24 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            else:
+                assert self.args.using_flex_checkpoint, "Expected using_flex_checkpoint to be enabled in this branch."
+
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
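+                # FlexCheckpoint resume: init_optimizer() materializes the optimizer states first,
+                # so dist.load_state_dict has destination tensors for the optimizer shards.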
using flex_checkpoint!" + + model = self._wrap_model(self.model_wrapped) + if model is not self.model: + self.model_wrapped = model + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + if resume_from_checkpoint is not None: + model_sharded_state_dict = self.model.sharded_state_dict() + self.optimizer.sharded_state_dict(model_sharded_state_dict) + init_optimizer(self.optimizer) + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict} + dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config) + self._load_scheduler(resume_from_checkpoint) else: model = self.model_wrapped if delay_optimizer_creation: @@ -1342,6 +1360,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): logger.warning( f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" ) + elif isinstance(self.optimizer, HybridParallelOptimizer): self.optimizer._step(parameters_list) else: @@ -1597,8 +1616,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate())) logs["global_step"] = int(self.state.global_step) - if in_auto_parallel_align_mode(): - logs["loss_md5"] = avg_loss._md5sum() + # if in_auto_parallel_align_mode(): + logs["loss_md5"] = avg_loss._md5sum() divisor = 2**30 # TODO(@gexiao): replace these codes with unified APIs in Paddle @@ -1968,7 +1987,6 @@ def apply_decay_param_fun(x): grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None, **optimizer_kwargs, ) - return self.optimizer def _apply_to_optimizer(self, action): @@ -2234,7 +2252,6 @@ def _wrap_model(self, model, training=True): mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) assert self.optimizer is not None, "optimizer is empty!" self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - # Pipeline mode if in_pipeline_parallel_mode: if self.args.amp_master_grad: @@ -2284,7 +2301,6 @@ def get_expected_keys(inputs, keys): if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) self.optimizer = fleet.distributed_optimizer(self.optimizer) - if ( hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap @@ -2292,7 +2308,6 @@ def get_expected_keys(inputs, keys): and "split_param" in split_parallel_config(self.args.sharding_parallel_config) ): model.register_sharding_comm_overlap_hook(self.optimizer) - # No pipeline mode, sharding only if not in_pipeline_parallel_mode and in_sharding_parallel_mode: # Sharded DDP! 
+        model_sharded_state_dict = self.model.sharded_state_dict()
+        if self.args.using_flex_checkpoint:
+            os.makedirs(output_dir, exist_ok=True)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -2763,23 +2780,32 @@ def _save_checkpoint(self, model, metrics=None):
                             signal_dir,
                         )
                 else:
-                    if self.dp_group.rank > 0:  # this should only work for MoE saving
-                        self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(),
-                            os.path.join(output_dir, optimizer_name),
-                            saved_signal_path,
-                        )
-
-                    else:
-                        state_dict = self.optimizer.state_dict()
-                        save_path = os.path.join(output_dir, optimizer_name)
-                        if self.args.use_async_save:
-                            assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                            self._async_optimizer_saver.run(
-                                state_dict, save_path, saved_signal_path=saved_signal_path
-                            )
-                        else:
-                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                    if not self.args.using_flex_checkpoint:
+                        if self.dp_group.rank > 0:  # this should only work for MoE saving
+                            self._save_ckpt_func(
+                                self._filter_moe_no_sync_optimizer_params(),
+                                os.path.join(output_dir, optimizer_name),
+                                saved_signal_path,
+                            )
+
+                        else:
+                            state_dict = self.optimizer.state_dict()
+                            save_path = os.path.join(output_dir, optimizer_name)
+                            if self.args.use_async_save:
+                                assert not strtobool(
+                                    os.getenv("FLAG_LLM_PDC", "False")
+                                ), "Dont support FLAG_LLM_PDC"
+                                self._async_optimizer_saver.run(
+                                    state_dict, save_path, saved_signal_path=saved_signal_path
+                                )
+                            else:
+                                self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                    else:
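+                        # FlexCheckpoint: write model and optimizer shards together as one
+                        # distributed checkpoint instead of per-rank optimizer files.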
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2790,7 +2816,12 @@ def _save_checkpoint(self, model, metrics=None):
                     or "remove_master_weight" not in self.args.unified_checkpoint_config
                 ):
                     paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
-        if self.args.should_save or self.args.use_expert_parallel:
+
+        if (
+            self.args.should_save
+            or self.args.use_expert_parallel
+            or (self.args.data_parallel_degree > 1 and not self.args.use_hybrid_parallel)
+        ):
             if not self.args.use_hybrid_parallel:
                 logger.info("Saving optimizer files.")
                 if self.args.unified_checkpoint:
@@ -2800,7 +2831,7 @@ def _save_checkpoint(self, model, metrics=None):
                         output_dir,
                         signal_dir,
                     )
-                else:
+                elif not self.args.using_flex_checkpoint:
                     if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                         self._save_ckpt_func(
                             self._filter_moe_no_sync_optimizer_params(),
@@ -2814,6 +2845,13 @@ def _save_checkpoint(self, model, metrics=None):
                             saved_signal_path,
                         )
+                else:
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    dist.save_state_dict(
+                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                        output_dir,
+                    )
+
         # FIXME: maybe only save one copy
         paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -3077,6 +3115,24 @@ def _save(
             with open(path, "w") as f:
                 json.dump(model_meta, f)
 
+    def _load_scheduler(self, checkpoint):
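+        """Load the LR scheduler state (and the GradScaler state when grad scaling is enabled) from `checkpoint`."""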
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+                )
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         self.runtime_timer.start("checkpoint loading time")
@@ -3118,6 +3174,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model = self.model_wrapped
+
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                 model=model,
                 optimizer=self.optimizer,
@@ -3149,18 +3206,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
 
-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-                )
+        self._load_scheduler(checkpoint)
 
         if self.args.offload_optim:
             logger.info("Offloading optimizer state...")
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index d8d88d1cd4ad..5289394d5f91 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -28,6 +28,7 @@
 import random
 import threading
 import time
+from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum
 from pathlib import Path
@@ -53,6 +54,21 @@
 from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
 from .utils.helper import distributed_file
 
+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizerV2,
+    )
+except:
+    DygraphShardingOptimizerV2 = None
+
+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+except:
+    DygraphShardingOptimizer = None
+
+
 __all__ = [
     "TrainOutput",
     "PredictionOutput",
@@ -1357,3 +1373,56 @@ def set_comm_config(configs, attr, dict_obj):
     set_comm_config("moe_sharding_configs", "check_nccl_config", nccl_config.get("moe_sharding_check", None))
     set_comm_config("default_comm_group_configs", "nccl_config", nccl_config.get("default", None))
     return strategy
+
+
+def init_optimizer(optimizer):
+    """
+    Initialize the optimizer's states according to its type.
+
+    For DygraphShardingOptimizer (V1), initializes accumulators for local parameters.
+    For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters.
+    For other cases, initializes accumulators for all parameters.
+
+    Args:
+        optimizer: The optimizer instance to be initialized.
+    """
+    if DygraphShardingOptimizer is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizer):
+        local_params = optimizer._rank2params[optimizer._sharding_rank]
+        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
+        return
+
+    elif DygraphShardingOptimizerV2 is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizerV2):
+
+        def init_param_optimizer_states(param_iter):
+            master_weights = {}
+            state_dict = {}
+            for static_name, shape in param_iter:
+                master_weights[static_name] = paddle.zeros(shape, dtype="float32")
+                for moment in ("moment1_0", "moment2_0"):
+                    key = f"{static_name}_fp32_master_0_{moment}"
+                    state_dict[key] = paddle.zeros(shape, dtype="float32")
+                for beta in ("beta1_pow_acc_0", "beta2_pow_acc_0"):
+                    key = f"{static_name}_fp32_master_0_{beta}"
+                    state_dict[key] = paddle.zeros((1,), dtype="float32")
+            return master_weights, state_dict
+
+        def buffer_params():
+            for buffer in optimizer._comm_buffer_list:
+                for param_name, grad_view in buffer._sharding_param_grad_view.items():
+                    numel = grad_view._param.numel().item()
+                    param_begin = grad_view._param_begin
+                    param_end = grad_view._param_end
+                    index = grad_view._index
+                    padding_begin = index + numel
+                    shape = (min(padding_begin, param_end) - param_begin,)
+                    if shape[0] > 0:
+                        yield param_name, shape
+
+        master_weights, state_dict = init_param_optimizer_states(buffer_params())
+        state_dict["master_weights"] = master_weights
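+        # Placeholder LR scheduler entry; the actual scheduler state is restored afterwards
+        # by Trainer._load_scheduler.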
+        state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}
+        optimizer.set_state_dict(state_dict)
+        return
+    optimizer._create_accumulators(
+        paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list
+    )
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 30a3e7b3dc62..bdecd8541932 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -407,6 +407,10 @@ class TrainingArguments:
             Whether to release gradients during training. Default is `False`.
         ckpt_quant_stage (`str`, *optional*):
             Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
+        using_flex_checkpoint (`bool`, *optional*):
+            Whether to use FlexCheckpoint for save and load. Default is False.
+        aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
+            The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
     """
 
     output_dir: str = field(
@@ -921,6 +925,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use async_save instead of paddle.save."},
     )
+    using_flex_checkpoint: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to use FlexCheckpoint for saving and loading."},
+    )
     ordered_save_group_size: int = field(
         default=0,
         metadata={
@@ -1082,6 +1090,13 @@ class TrainingArguments:
         default=None, metadata={"help": "NCCL中通信组的细粒度控制的配置文件路径, 默认值为None, 代表不启用此项配置"}
     )
+    aoa_config: Optional[dict[str, list[str]]] = field(
+        default=None,
+        metadata={
+            "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None."
+        },
+    )
+
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
@@ -1792,6 +1807,9 @@ def is_segment_parallel_supported():
             # DP use hybrid group
             strategy = fleet.DistributedStrategy()
             fleet.init(is_collective=True, strategy=strategy)
+        elif self.using_flex_checkpoint:
+            strategy = fleet.DistributedStrategy()
+            fleet.init(is_collective=True, strategy=strategy)
         else:
             paddle.distributed.init_parallel_env()
 
@@ -2355,6 +2373,8 @@ def should_save_model_state(self):
             return True
         elif self.enable_auto_parallel:
             return True
+        elif self.using_flex_checkpoint:
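+            # FlexCheckpoint writes the model weights together with the optimizer shards in
+            # Trainer._save_checkpoint, so the regular model-state save is skipped here.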
+            return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
             return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
diff --git a/paddlenlp/trainer/utils/ckpt_converter.py b/paddlenlp/trainer/utils/ckpt_converter.py
index 23f085e18f44..556cc0b4754a 100644
--- a/paddlenlp/trainer/utils/ckpt_converter.py
+++ b/paddlenlp/trainer/utils/ckpt_converter.py
@@ -19,17 +19,17 @@
 from typing import List, Union
 
 import paddle
-from paddle.distributed.checkpoint.load_state_dict import (
+from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
     _load_state_dict,
     get_rank_to_read_files,
 )
-from paddle.distributed.checkpoint.metadata import (
+from paddle.distributed.flex_checkpoint.dcp.metadata import (
     LocalTensorIndex,
     LocalTensorMetadata,
     Metadata,
 )
-from paddle.distributed.checkpoint.utils import flatten_state_dict
-from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
 
 MODEL_WEIGHT_SUFFIX = ".pdparams"
 OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 84c9bcc7ff4e..18c48f5470a8 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -30,6 +30,9 @@
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
+from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
+    build_sharded_state_dict,
+)
 
 from paddlenlp.transformers.refined_recompute import (
     RRColumnParallelLinear,
@@ -1367,7 +1370,6 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
 
     @classmethod
     def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
-
         from paddlenlp.transformers.conversion_utils import split_or_merge_func
 
         fn = split_or_merge_func(
@@ -1995,6 +1997,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         )
         return logits
 
+    def sharded_state_dict(
+        self,
+        structured_name_prefix: str = "",
+    ):
+        axis = 0 if self.transpose_y else 1
+        state_dict = self.state_dict(structured_name_prefix="")
+        return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)
+
 
 class LlamaForCausalLM(LlamaPretrainedModel):
     enable_to_static_method = True
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index b5756896c65a..b478b253835f 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -3167,6 +3167,19 @@ def state_dict(self, *args, **kwargs):
 
         return state_dict
 
+    def sharded_state_dict(self, *args, **kwargs):
+        sharded_state_dict = super().sharded_state_dict(*args, **kwargs)
+        if self._single_to_pp_mapping is None:
+            self._set_pipeline_name_mapping()
+        assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!"
+
+        for k in list(sharded_state_dict.keys()):
+            v = sharded_state_dict.pop(k)
+            v.key = self._pp_to_single_mapping[k]
+            sharded_state_dict[self._pp_to_single_mapping[k]] = v
+
+        return sharded_state_dict
+
     def set_state_dict(self, state_dict, *args, **kwargs):
         if self._single_to_pp_mapping is None:
             self._set_pipeline_name_mapping()