From ab5668425232c4bbbeecbee842a8b467043f6243 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Fri, 22 Aug 2025 14:15:52 +0800 Subject: [PATCH 1/3] adapter flex_checkpoint --- paddleformers/trainer/trainer.py | 106 +++++++++++++------ paddleformers/trainer/trainer_utils.py | 67 ++++++++++++ paddleformers/trainer/training_args.py | 14 +++ paddleformers/transformers/llama/modeling.py | 11 ++ paddleformers/transformers/model_utils.py | 13 +++ 5 files changed, 179 insertions(+), 32 deletions(-) diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 8dd4904b4ab..b5af659e1dd 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -166,6 +166,7 @@ get_last_checkpoint, get_scheduler, has_length, + init_optimizer, set_seed, should_skip_data, speed_metrics, @@ -936,7 +937,7 @@ def train( self._memory_tracker.start() if not self.args.enable_auto_parallel: - if not self.args.should_load_sharding_stage1_model: + if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint: self._load_from_checkpoint(resume_from_checkpoint) if self.args.should_load_sharding_stage1_model: @@ -956,7 +957,7 @@ def train( if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) - else: + elif not self.args.using_flex_checkpoint: model = self._wrap_model(self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: @@ -964,6 +965,23 @@ def train( if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) + else: + assert self.args.using_flex_checkpoint, "using_flex_checkpoint should be True" + model = self._wrap_model(self.model_wrapped) + if model is not self.model: + self.model_wrapped = model + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + if resume_from_checkpoint is not None: + model_sharded_state_dict = self.model.sharded_state_dict() + self.optimizer.sharded_state_dict(model_sharded_state_dict) + init_optimizer(self.optimizer) + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict} + dist.load_state_dict(sharded_state_dict, resume_from_checkpoint) + self._load_scheduler(resume_from_checkpoint) else: model = self.model_wrapped if delay_optimizer_creation: @@ -2730,6 +2748,10 @@ def _save_checkpoint(self, model, metrics=None): else: self.save_model(output_dir) + model_sharded_state_dict = self.model.sharded_state_dict() + if self.args.using_flex_checkpoint: + os.makedirs(output_dir, exist_ok=True) + # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: metric_to_check = self.args.metric_for_best_model @@ -2791,26 +2813,34 @@ def _save_checkpoint(self, model, metrics=None): self.optimizer, output_dir, signal_dir, - self.args.optim_shard_num, ) else: - if self.dp_group.rank > 0: # this should only work for MoE saving - self._save_ckpt_func( - self._filter_moe_no_sync_optimizer_params(), - os.path.join(output_dir, optimizer_name), - saved_signal_path, - ) - - else: - state_dict = self.optimizer.state_dict() - save_path = os.path.join(output_dir, optimizer_name) - if self.args.use_async_save: - assert not 
strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC" - self._async_optimizer_saver.run( - state_dict, save_path, saved_signal_path=saved_signal_path + if not self.args.using_flex_checkpoint: + if self.dp_group.rank > 0: # this should only work for MoE saving + self._save_ckpt_func( + self._filter_moe_no_sync_optimizer_params(), + os.path.join(output_dir, optimizer_name), + saved_signal_path, ) + else: - self._save_ckpt_func(state_dict, save_path, saved_signal_path) + state_dict = self.optimizer.state_dict() + save_path = os.path.join(output_dir, optimizer_name) + if self.args.use_async_save: + assert not strtobool( + os.getenv("FLAG_LLM_PDC", "False") + ), "Dont support FLAG_LLM_PDC" + self._async_optimizer_saver.run( + state_dict, save_path, saved_signal_path=saved_signal_path + ) + else: + self._save_ckpt_func(state_dict, save_path, saved_signal_path) + else: + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + dist.save_state_dict( + {**model_sharded_state_dict, **optimizer_sharded_state_dict}, + output_dir, + ) else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 @@ -2830,9 +2860,8 @@ def _save_checkpoint(self, model, metrics=None): self.optimizer, output_dir, signal_dir, - self.args.optim_shard_num, ) - else: + elif not self.args.using_flex_checkpoint: if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel: self._save_ckpt_func( self._filter_moe_no_sync_optimizer_params(), @@ -2846,6 +2875,12 @@ def _save_checkpoint(self, model, metrics=None): saved_signal_path, ) + else: + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + dist.save_state_dict( + {**model_sharded_state_dict, **optimizer_sharded_state_dict}, + output_dir, + ) # FIXME: maybe only save one copy paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -3110,6 +3145,24 @@ def _save( with open(path, "w") as f: json.dump(model_meta, f) + def _load_scheduler(self, checkpoint): + if checkpoint is None: + self.runtime_timer.stop() + return + + if not self.args.ignore_load_lr_and_optim: + if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + self.lr_scheduler.set_state_dict( + paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME))) + ) + else: + raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}") + + if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)): + self.scaler.load_state_dict( + paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True) + ) + def _load_optimizer_and_scheduler(self, checkpoint): """If optimizer and scheduler states exist, load them.""" self.runtime_timer.start("checkpoint loading time") @@ -3185,18 +3238,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): gc.collect() empty_device_cache() - if not self.args.ignore_load_lr_and_optim: - if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)): - self.lr_scheduler.set_state_dict( - paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME))) - ) - else: - raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}") - - if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)): - self.scaler.load_state_dict( - 
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True) - ) + self._load_scheduler(checkpoint) if self.args.offload_optim: logger.info("Offloading optimizer state...") diff --git a/paddleformers/trainer/trainer_utils.py b/paddleformers/trainer/trainer_utils.py index 0337b3276a3..2812a5bfb19 100644 --- a/paddleformers/trainer/trainer_utils.py +++ b/paddleformers/trainer/trainer_utils.py @@ -53,6 +53,20 @@ from ..utils.tools import get_env_device from .utils.helper import distributed_file +try: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizerV2, + ) +except: + DygraphShardingOptimizerV2 = None + +try: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + ) +except: + DygraphShardingOptimizer = None + __all__ = [ "TrainOutput", "PredictionOutput", @@ -1283,3 +1297,56 @@ def _insert_sync(self, sync_var, src, mp_group, sync_mode): # Move it back to pin memory if original_device == "pin_memory": sync_var = paddle.to_tensor(sync_var, place=paddle.CUDAPinnedPlace()) + + +def init_optimizer(optimizer): + """ + Initialize the optimizer's states according to its type. + + For DygraphShardingOptimizer (V1), initializes accumulators for local parameters. + For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters. + For other cases, initializes accumulators for all parameters. + + Args: + optimizer: The optimizer instance to be initialized. + """ + if DygraphShardingOptimizer is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizer): + local_params = optimizer._rank2params[optimizer._sharding_rank] + optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params) + return + + elif DygraphShardingOptimizerV2 is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizerV2): + + def init_param_optimizer_states(param_iter): + master_weights = {} + state_dict = {} + for static_name, shape in param_iter: + master_weights[static_name] = paddle.zeros(shape, dtype="float32") + for moment in ("moment1_0", "moment2_0"): + key = f"{static_name}_fp32_master_0_{moment}" + state_dict[key] = paddle.zeros(shape, dtype="float32") + for beta in ("beta1_pow_acc_0", "beta2_pow_acc_0"): + key = f"{static_name}_fp32_master_0_{beta}" + state_dict[key] = paddle.zeros((1,), dtype="float32") + return master_weights, state_dict + + def buffer_params(): + for buffer in optimizer._comm_buffer_list: + for param_name, grad_view in buffer._sharding_param_grad_view.items(): + numel = grad_view._param.numel().item() + param_begin = grad_view._param_begin + param_end = grad_view._param_end + index = grad_view._index + padding_begin = index + numel + shape = (min(padding_begin, param_end) - param_begin,) + if shape[0] > 0: + yield param_name, shape + + master_weights, state_dict = init_param_optimizer_states(buffer_params()) + state_dict["master_weights"] = master_weights + state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06} + optimizer.set_state_dict(state_dict) + return + optimizer._create_accumulators( + paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list + ) diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py index 16bce8e4c71..8be4c91b65f 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -401,6 
+401,10 @@ class TrainingArguments: Whether to release gradients during training. Default is `False`. ckpt_quant_stage (`str`, *optional*): Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0). + using_flex_checkpoint(`bool`, *optional*): + Whether to use FlexCheckpoint for save and load. Default is False. + aoa_config (`Optional[dict[str, list[str]]]`, *optional*): + The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None. """ output_dir: str = field( @@ -1080,6 +1084,16 @@ class TrainingArguments: default=False, metadata={"help": "是否开启单路sharding时global norm通信拆分全局通信组为pp通信和mp通信分别做"}, ) + using_flex_checkpoint: Optional[bool] = field( + default=False, + metadata={"help": "Whether use FlexCheckpoint."}, + ) + aoa_config: Optional[dict[str, list[str]]] = field( + default=None, + metadata={ + "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None." + }, + ) def __post_init__(self): world_size = paddle.distributed.get_world_size() diff --git a/paddleformers/transformers/llama/modeling.py b/paddleformers/transformers/llama/modeling.py index 4f91b71b702..ef6fcfd4405 100755 --- a/paddleformers/transformers/llama/modeling.py +++ b/paddleformers/transformers/llama/modeling.py @@ -30,6 +30,9 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ( + build_sharded_state_dict, +) from ..refined_recompute import ( RRColumnParallelLinear, @@ -1987,6 +1990,14 @@ def forward(self, hidden_states, tensor_parallel_output=None): ) return logits + def sharded_state_dict( + self, + structured_name_prefix: str = "", + ): + axis = 0 if self.transpose_y else 1 + state_dict = self.state_dict(structured_name_prefix="") + return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix) + class LlamaForCausalLM(LlamaPretrainedModel): enable_to_static_method = True diff --git a/paddleformers/transformers/model_utils.py b/paddleformers/transformers/model_utils.py index 02c7e1cdfcb..236b9a671cc 100644 --- a/paddleformers/transformers/model_utils.py +++ b/paddleformers/transformers/model_utils.py @@ -3158,6 +3158,19 @@ def set_state_dict(self, state_dict, *args, **kwargs): ret = super().set_state_dict(state_dict, *args, **kwargs) return ret + def sharded_state_dict(self, *args, **kwargs): + sharded_state_dict = super().sharded_state_dict(*args, **kwargs) + if self._single_to_pp_mapping is None: + self._set_pipeline_name_mapping() + assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!" 
+ + for k in list(sharded_state_dict.keys()): + v = sharded_state_dict.pop(k) + v.tensor_key = self._pp_to_single_mapping[k] + sharded_state_dict[self._pp_to_single_mapping[k]] = v + + return sharded_state_dict + def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False): """ From 3272d16a2e8e56046ffdde282f4f0ea3706728d8 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Mon, 25 Aug 2025 16:36:36 +0800 Subject: [PATCH 2/3] re-impl init_optimier --- paddleformers/trainer/trainer_utils.py | 28 ++++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/paddleformers/trainer/trainer_utils.py b/paddleformers/trainer/trainer_utils.py index 2812a5bfb19..e2ed5ae0fc9 100644 --- a/paddleformers/trainer/trainer_utils.py +++ b/paddleformers/trainer/trainer_utils.py @@ -1320,31 +1320,37 @@ def init_optimizer(optimizer): def init_param_optimizer_states(param_iter): master_weights = {} state_dict = {} - for static_name, shape in param_iter: - master_weights[static_name] = paddle.zeros(shape, dtype="float32") - for moment in ("moment1_0", "moment2_0"): - key = f"{static_name}_fp32_master_0_{moment}" + moments = ("moment1_0", "moment2_0") + betas = ("beta1_pow_acc_0", "beta2_pow_acc_0") + for static_name, shape, no_need_master_weights in param_iter: + if not no_need_master_weights: + master_weights[static_name] = paddle.zeros(shape, dtype="float32") + prefix = f"{static_name}_fp32_master_0_" + else: + prefix = f"{static_name}_" + + for moment in moments: + key = f"{prefix}{moment}" state_dict[key] = paddle.zeros(shape, dtype="float32") - for beta in ("beta1_pow_acc_0", "beta2_pow_acc_0"): - key = f"{static_name}_fp32_master_0_{beta}" + for beta in betas: + key = f"{prefix}{beta}" state_dict[key] = paddle.zeros((1,), dtype="float32") return master_weights, state_dict def buffer_params(): for buffer in optimizer._comm_buffer_list: for param_name, grad_view in buffer._sharding_param_grad_view.items(): - numel = grad_view._param.numel().item() param_begin = grad_view._param_begin param_end = grad_view._param_end - index = grad_view._index - padding_begin = index + numel - shape = (min(padding_begin, param_end) - param_begin,) + shape = (param_end - param_begin,) + no_need_master_weights = grad_view._param.dtype == paddle.float32 if shape[0] > 0: - yield param_name, shape + yield param_name, shape, no_need_master_weights master_weights, state_dict = init_param_optimizer_states(buffer_params()) state_dict["master_weights"] = master_weights state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06} + optimizer.set_state_dict(state_dict) return optimizer._create_accumulators( From f3debb0d5b54dbc353ed2ac6d3614e3ef5cd1887 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 Date: Thu, 4 Sep 2025 02:26:10 +0000 Subject: [PATCH 3/3] fix conflict --- paddleformers/trainer/training_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py index 433864d34dd..9cd1e04c554 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -1093,6 +1093,7 @@ class TrainingArguments: metadata={ "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None." }, + ) convert_from_hf: Optional[bool] = field( default=False, metadata={"help": "Load model from HuggingFace safetensors."},
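
Editor's usage note (illustrative, not part of the patches above): the FlexCheckpoint path
introduced here composes the model's and the optimizer's sharded state dicts and hands the
merged dict to paddle.distributed.save_state_dict / load_state_dict, with init_optimizer()
pre-allocating optimizer states before a load. The sketch below mirrors that flow in
isolation; it assumes the distributed environment is already initialized and that `model`
and `optimizer` expose the sharded_state_dict() methods added by these patches. `ckpt_dir`,
`flex_save`, and `flex_load` are placeholder names for illustration, not APIs from the patch.

    # Minimal sketch of the FlexCheckpoint save/load flow wired into the Trainer above.
    import paddle.distributed as dist

    from paddleformers.trainer.trainer_utils import init_optimizer


    def flex_save(model, optimizer, ckpt_dir):
        # Model shards first; optimizer shards are keyed off the model mapping.
        model_sd = model.sharded_state_dict()
        optim_sd = optimizer.sharded_state_dict(model_sd)
        dist.save_state_dict({**model_sd, **optim_sd}, ckpt_dir)


    def flex_load(model, optimizer, ckpt_dir):
        model_sd = model.sharded_state_dict()
        # The Trainer calls sharded_state_dict() once before init_optimizer(),
        # then again afterwards; the same order is kept here.
        optimizer.sharded_state_dict(model_sd)
        init_optimizer(optimizer)  # pre-allocate accumulators / master weights
        optim_sd = optimizer.sharded_state_dict(model_sd)
        dist.load_state_dict({**model_sd, **optim_sd}, ckpt_dir)

In the Trainer itself this path is taken only when --using_flex_checkpoint is set; per the
help text added in training_args.py, aoa_config can additionally describe the mapping
between model weight names and the checkpoint content when the two differ.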