diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 55fb28d5c09..60d5be23a1e 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -166,6 +166,7 @@ get_last_checkpoint, get_scheduler, has_length, + init_optimizer, set_seed, should_skip_data, speed_metrics, @@ -939,7 +940,7 @@ def train( self._memory_tracker.start() if not self.args.enable_auto_parallel: - if not self.args.should_load_sharding_stage1_model: + if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint: self._load_from_checkpoint(resume_from_checkpoint) if self.args.should_load_sharding_stage1_model: @@ -959,7 +960,7 @@ def train( if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) - else: + elif not self.args.using_flex_checkpoint: model = self._wrap_model(self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: @@ -967,6 +968,23 @@ def train( if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) + else: + assert self.args.using_flex_checkpoint, "using_flex_checkpoint should be True" + model = self._wrap_model(self.model_wrapped) + if model is not self.model: + self.model_wrapped = model + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + if resume_from_checkpoint is not None: + model_sharded_state_dict = self.model.sharded_state_dict() + self.optimizer.sharded_state_dict(model_sharded_state_dict) + init_optimizer(self.optimizer) + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict} + dist.load_state_dict(sharded_state_dict, resume_from_checkpoint) + self._load_scheduler(resume_from_checkpoint) else: model = self.model_wrapped if delay_optimizer_creation: @@ -2735,6 +2753,10 @@ def _save_checkpoint(self, model, metrics=None): else: self.save_model(output_dir) + model_sharded_state_dict = self.model.sharded_state_dict() + if self.args.using_flex_checkpoint: + os.makedirs(output_dir, exist_ok=True) + # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: metric_to_check = self.args.metric_for_best_model @@ -2796,26 +2818,34 @@ def _save_checkpoint(self, model, metrics=None): self.optimizer, output_dir, signal_dir, - self.args.optim_shard_num, ) else: - if self.dp_group.rank > 0: # this should only work for MoE saving - self._save_ckpt_func( - self._filter_moe_no_sync_optimizer_params(), - os.path.join(output_dir, optimizer_name), - saved_signal_path, - ) - - else: - state_dict = self.optimizer.state_dict() - save_path = os.path.join(output_dir, optimizer_name) - if self.args.use_async_save: - assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC" - self._async_optimizer_saver.run( - state_dict, save_path, saved_signal_path=saved_signal_path + if not self.args.using_flex_checkpoint: + if self.dp_group.rank > 0: # this should only work for MoE saving + self._save_ckpt_func( + self._filter_moe_no_sync_optimizer_params(), + os.path.join(output_dir, optimizer_name), + saved_signal_path, ) + else: - self._save_ckpt_func(state_dict, save_path, saved_signal_path) + state_dict = 
self.optimizer.state_dict() + save_path = os.path.join(output_dir, optimizer_name) + if self.args.use_async_save: + assert not strtobool( + os.getenv("FLAG_LLM_PDC", "False") + ), "Dont support FLAG_LLM_PDC" + self._async_optimizer_saver.run( + state_dict, save_path, saved_signal_path=saved_signal_path + ) + else: + self._save_ckpt_func(state_dict, save_path, saved_signal_path) + else: + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + dist.save_state_dict( + {**model_sharded_state_dict, **optimizer_sharded_state_dict}, + output_dir, + ) else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 @@ -2835,9 +2865,8 @@ def _save_checkpoint(self, model, metrics=None): self.optimizer, output_dir, signal_dir, - self.args.optim_shard_num, ) - else: + elif not self.args.using_flex_checkpoint: if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel: self._save_ckpt_func( self._filter_moe_no_sync_optimizer_params(), @@ -2851,6 +2880,12 @@ def _save_checkpoint(self, model, metrics=None): saved_signal_path, ) + else: + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + dist.save_state_dict( + {**model_sharded_state_dict, **optimizer_sharded_state_dict}, + output_dir, + ) # FIXME: maybe only save one copy paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -3122,6 +3157,24 @@ def _save( with open(path, "w") as f: json.dump(model_meta, f) + def _load_scheduler(self, checkpoint): + if checkpoint is None: + self.runtime_timer.stop() + return + + if not self.args.ignore_load_lr_and_optim: + if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + self.lr_scheduler.set_state_dict( + paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME))) + ) + else: + raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}") + + if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)): + self.scaler.load_state_dict( + paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True) + ) + def _load_optimizer_and_scheduler(self, checkpoint): """If optimizer and scheduler states exist, load them.""" self.runtime_timer.start("checkpoint loading time") @@ -3197,18 +3250,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): gc.collect() empty_device_cache() - if not self.args.ignore_load_lr_and_optim: - if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)): - self.lr_scheduler.set_state_dict( - paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME))) - ) - else: - raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}") - - if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)): - self.scaler.load_state_dict( - paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True) - ) + self._load_scheduler(checkpoint) if self.args.offload_optim: logger.info("Offloading optimizer state...") diff --git a/paddleformers/trainer/trainer_utils.py b/paddleformers/trainer/trainer_utils.py index 0337b3276a3..e2ed5ae0fc9 100644 --- a/paddleformers/trainer/trainer_utils.py +++ b/paddleformers/trainer/trainer_utils.py @@ -53,6 +53,20 @@ from ..utils.tools import get_env_device from .utils.helper import distributed_file +try: + from 
paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizerV2, + ) +except: + DygraphShardingOptimizerV2 = None + +try: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + ) +except: + DygraphShardingOptimizer = None + __all__ = [ "TrainOutput", "PredictionOutput", @@ -1283,3 +1297,62 @@ def _insert_sync(self, sync_var, src, mp_group, sync_mode): # Move it back to pin memory if original_device == "pin_memory": sync_var = paddle.to_tensor(sync_var, place=paddle.CUDAPinnedPlace()) + + +def init_optimizer(optimizer): + """ + Initialize the optimizer's states according to its type. + + For DygraphShardingOptimizer (V1), initializes accumulators for local parameters. + For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters. + For other cases, initializes accumulators for all parameters. + + Args: + optimizer: The optimizer instance to be initialized. + """ + if DygraphShardingOptimizer is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizer): + local_params = optimizer._rank2params[optimizer._sharding_rank] + optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params) + return + + elif DygraphShardingOptimizerV2 is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizerV2): + + def init_param_optimizer_states(param_iter): + master_weights = {} + state_dict = {} + moments = ("moment1_0", "moment2_0") + betas = ("beta1_pow_acc_0", "beta2_pow_acc_0") + for static_name, shape, no_need_master_weights in param_iter: + if not no_need_master_weights: + master_weights[static_name] = paddle.zeros(shape, dtype="float32") + prefix = f"{static_name}_fp32_master_0_" + else: + prefix = f"{static_name}_" + + for moment in moments: + key = f"{prefix}{moment}" + state_dict[key] = paddle.zeros(shape, dtype="float32") + for beta in betas: + key = f"{prefix}{beta}" + state_dict[key] = paddle.zeros((1,), dtype="float32") + return master_weights, state_dict + + def buffer_params(): + for buffer in optimizer._comm_buffer_list: + for param_name, grad_view in buffer._sharding_param_grad_view.items(): + param_begin = grad_view._param_begin + param_end = grad_view._param_end + shape = (param_end - param_begin,) + no_need_master_weights = grad_view._param.dtype == paddle.float32 + if shape[0] > 0: + yield param_name, shape, no_need_master_weights + + master_weights, state_dict = init_param_optimizer_states(buffer_params()) + state_dict["master_weights"] = master_weights + state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06} + + optimizer.set_state_dict(state_dict) + return + optimizer._create_accumulators( + paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list + ) diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py index c505856532b..9cd1e04c554 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -401,6 +401,10 @@ class TrainingArguments: Whether to release gradients during training. Default is `False`. ckpt_quant_stage (`str`, *optional*): Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0). + using_flex_checkpoint(`bool`, *optional*): + Whether to use FlexCheckpoint for save and load. Default is False. 
+ aoa_config (`Optional[dict[str, list[str]]]`, *optional*): + The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None. """ output_dir: str = field( @@ -1080,6 +1084,16 @@ class TrainingArguments: default=False, metadata={"help": "是否开启单路sharding时global norm通信拆分全局通信组为pp通信和mp通信分别做"}, ) + using_flex_checkpoint: Optional[bool] = field( + default=False, + metadata={"help": "Whether use FlexCheckpoint."}, + ) + aoa_config: Optional[dict[str, list[str]]] = field( + default=None, + metadata={ + "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None." + }, + ) convert_from_hf: Optional[bool] = field( default=False, metadata={"help": "Load model from HuggingFace safetensors."}, diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py index e60a6370faf..6e3b4196db4 100755 --- a/paddleformers/transformers/llama/modeling.py +++ b/paddleformers/transformers/llama/modeling.py @@ -30,6 +30,9 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ( + build_sharded_state_dict, +) from ..refined_recompute import ( RRColumnParallelLinear, @@ -1988,6 +1991,14 @@ def forward(self, hidden_states, tensor_parallel_output=None): ) return logits + def sharded_state_dict( + self, + structured_name_prefix: str = "", + ): + axis = 0 if self.transpose_y else 1 + state_dict = self.state_dict(structured_name_prefix="") + return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix) + class LlamaForCausalLM(LlamaPretrainedModel): enable_to_static_method = True diff --git a/paddleformers/transformers/model_utils.py b/paddleformers/transformers/model_utils.py index d4f8817db88..5b8d47e5929 100644 --- a/paddleformers/transformers/model_utils.py +++ b/paddleformers/transformers/model_utils.py @@ -3199,6 +3199,19 @@ def set_state_dict(self, state_dict, *args, **kwargs): ret = super().set_state_dict(state_dict, *args, **kwargs) return ret + def sharded_state_dict(self, *args, **kwargs): + sharded_state_dict = super().sharded_state_dict(*args, **kwargs) + if self._single_to_pp_mapping is None: + self._set_pipeline_name_mapping() + assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!" + + for k in list(sharded_state_dict.keys()): + v = sharded_state_dict.pop(k) + v.tensor_key = self._pp_to_single_mapping[k] + sharded_state_dict[self._pp_to_single_mapping[k]] = v + + return sharded_state_dict + def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False): """
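Usage sketch (illustrative, not part of the patch): with this change, checkpoint save goes through dist.save_state_dict on a single merged sharded state dict (model plus optimizer) written to the checkpoint directory, and resume pre-allocates optimizer accumulators and master weights via init_optimizer so dist.load_state_dict can fill them in place, while the LR scheduler and grad scaler are restored separately by _load_scheduler. The sketch below shows how the new TrainingArguments fields would be used; the import path, the model/train_dataset placeholders, and the resume directory are assumptions for illustration only.

    # Minimal sketch, assuming Trainer and TrainingArguments are exported from
    # paddleformers.trainer as laid out in this diff; `model`, `train_dataset`,
    # and the checkpoint path are caller-supplied placeholders.
    from paddleformers.trainer import Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints",
        using_flex_checkpoint=True,  # new flag: save/load via FlexCheckpoint sharded state dicts
        aoa_config=None,             # new field: optional mapping between model weights and checkpoint content
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train(resume_from_checkpoint="./checkpoints/checkpoint-500")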