From 92d9b66ce9b4578c60db18c53a66a0bd97673682 Mon Sep 17 00:00:00 2001 From: zty-king <17786324919@163.com> Date: Wed, 17 Sep 2025 12:31:18 +0000 Subject: [PATCH 1/6] adapt_flex_checkpoint --- paddlenlp/trainer/trainer.py | 168 ++++++++++++++++++----- paddlenlp/trainer/trainer_utils.py | 88 ++++++++++++ paddlenlp/trainer/training_args.py | 103 +++++++++++++- paddlenlp/transformers/llama/modeling.py | 12 +- paddlenlp/transformers/model_utils.py | 13 ++ 5 files changed, 346 insertions(+), 38 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 9a3bad94c59b..f5fbac5bf27f 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -161,6 +161,7 @@ get_last_checkpoint, get_scheduler, has_length, + init_optimizer, set_seed, should_skip_data, speed_metrics, @@ -199,7 +200,6 @@ if is_datasets_available(): import datasets - try: from paddle.distributed.fleet.utils import mix_precision_utils except: @@ -812,6 +812,10 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None): if resume_from_checkpoint is not None: path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema") + if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None: + success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint) + else: + success, err_msg = True, None if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None: success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint) else: @@ -822,6 +826,11 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None): self.zcc_manager.set_ema_state_dict(path) else: logger.info(f"ZCC EMA does not load {path} because {err_msg}") + if success: + logger.info(f"ZCC EMA load from {path}") + self.zcc_manager.set_ema_state_dict(path) + else: + logger.info(f"ZCC EMA does not load {path} because {err_msg}") else: logger.info(f"ZCC EMA state dict not found, in: {path}") @@ -929,13 +938,13 @@ def train( self._memory_tracker.start() if not self.args.enable_auto_parallel: - if not self.args.should_load_sharding_stage1_model: + if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint: self._load_from_checkpoint(resume_from_checkpoint) if self.args.should_load_sharding_stage1_model: model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint) - elif self.args.should_save_sharding_stage1_model: + elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint: # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model. # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks. 
model = self._wrap_model(self.model_wrapped) @@ -949,13 +958,43 @@ def train( if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) + elif self.args.load_flex_checkpoint: + model = self._wrap_model(self.model_wrapped) + if model is not self.model: + self.model_wrapped = model + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + if resume_from_checkpoint is not None: + if not self.args.ignore_load_lr_and_optim: + model_sharded_state_dict = self.model.sharded_state_dict() + accessible_files = os.listdir(resume_from_checkpoint) + metadata_files = [file for file in accessible_files if file.endswith(".metadata")] + assert len(metadata_files) == 1, "Only support one metadata file now." + metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0])) + state_dict_metadata = metadata.state_dict_metadata + init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata) + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict} + dist.load_state_dict( + sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config + ) + self._load_scheduler(resume_from_checkpoint) + else: + model_sharded_state_dict = self.model.sharded_state_dict() + sharded_state_dict = model_sharded_state_dict + dist.load_state_dict( + sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config + ) else: model = self._wrap_model(self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model + if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) + self._load_optimizer_and_scheduler(resume_from_checkpoint) else: model = self.model_wrapped @@ -1357,6 +1396,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): logger.warning( f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}" ) + elif isinstance(self.optimizer, HybridParallelOptimizer): self.optimizer._step(parameters_list) else: @@ -1993,7 +2033,6 @@ def apply_decay_param_fun(x): grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None, **optimizer_kwargs, ) - return self.optimizer def _apply_to_optimizer(self, action): @@ -2033,6 +2072,13 @@ def _load_rng_state(self, checkpoint): return rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth") + if not os.path.isfile(rng_file): + logger.info( + "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " + "fashion, reproducibility is not guaranteed." + ) + return + rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth") if not os.path.isfile(rng_file): logger.info( "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " @@ -2238,7 +2284,6 @@ def _wrap_model(self, model, training=True): mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) assert self.optimizer is not None, "optimizer is empty!" 
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - # Pipeline mode if in_pipeline_parallel_mode: if self.args.amp_master_grad: @@ -2288,7 +2333,6 @@ def get_expected_keys(inputs, keys): if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) self.optimizer = fleet.distributed_optimizer(self.optimizer) - if ( hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap @@ -2296,7 +2340,6 @@ def get_expected_keys(inputs, keys): and "split_param" in split_parallel_config(self.args.sharding_parallel_config) ): model.register_sharding_comm_overlap_hook(self.optimizer) - # No pipeline mode, sharding only if not in_pipeline_parallel_mode and in_sharding_parallel_mode: # Sharded DDP! @@ -2310,7 +2353,6 @@ def get_expected_keys(inputs, keys): model = paddle.distributed.fleet.meta_parallel.TensorParallel( model, hcg, strategy=fleet.fleet._user_defined_strategy ) - if ShardingOption.SHARD_OP in self.args.sharding: if self.args.amp_master_grad: mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use @@ -2352,6 +2394,7 @@ def get_expected_keys(inputs, keys): offload=cpu_offload, **extra_kwargs, ) + if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad: assert hasattr(optimizer, "use_main_grad"), ( "Current installed paddle doesn't support sharding stage 2 with main grad, " @@ -2377,7 +2420,6 @@ def get_expected_keys(inputs, keys): if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) self.optimizer = fleet.distributed_optimizer(self.optimizer) - # stage1 has v1 and v2 version if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding: if "split_param" in self.args.sharding_parallel_config: @@ -2720,6 +2762,10 @@ def _save_checkpoint(self, model, metrics=None): else: self.save_model(output_dir) + if self.args.save_flex_checkpoint: + model_sharded_state_dict = self.model.sharded_state_dict() + os.makedirs(output_dir, exist_ok=True) + # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: metric_to_check = self.args.metric_for_best_model @@ -2779,23 +2825,38 @@ def _save_checkpoint(self, model, metrics=None): signal_dir, ) else: - if self.dp_group.rank > 0: # this should only work for MoE saving - self._save_ckpt_func( - self._filter_moe_no_sync_optimizer_params(), - os.path.join(output_dir, optimizer_name), - saved_signal_path, + if self.args.save_flex_checkpoint: + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + dist.save_state_dict( + {**model_sharded_state_dict, **optimizer_sharded_state_dict}, + output_dir, ) - + if self.args.should_save: + if self.tokenizer is not None and self.args.save_tokenizer: + self.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) else: - state_dict = self.optimizer.state_dict() - save_path = os.path.join(output_dir, optimizer_name) - if self.args.use_async_save: - assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC" - self._async_optimizer_saver.run( - state_dict, save_path, saved_signal_path=saved_signal_path + if self.dp_group.rank > 0: # this should only work for MoE saving + self._save_ckpt_func( + self._filter_moe_no_sync_optimizer_params(), + 
os.path.join(output_dir, optimizer_name), + saved_signal_path, ) + else: - self._save_ckpt_func(state_dict, save_path, saved_signal_path) + state_dict = self.optimizer.state_dict() + save_path = os.path.join(output_dir, optimizer_name) + if self.args.use_async_save: + assert not strtobool( + os.getenv("FLAG_LLM_PDC", "False") + ), "Dont support FLAG_LLM_PDC" + self._async_optimizer_saver.run( + state_dict, save_path, saved_signal_path=saved_signal_path + ) + else: + self._save_ckpt_func(state_dict, save_path, saved_signal_path) + else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 @@ -2806,7 +2867,12 @@ def _save_checkpoint(self, model, metrics=None): or "remove_master_weight" not in self.args.unified_checkpoint_config ): paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) - if self.args.should_save or self.args.use_expert_parallel: + + if ( + self.args.should_save + or self.args.use_expert_parallel + or (self.args.data_parallel_degree > 1 and self.args.save_flex_checkpoint) + ): if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") if self.args.unified_checkpoint: @@ -2816,6 +2882,17 @@ def _save_checkpoint(self, model, metrics=None): output_dir, signal_dir, ) + elif self.args.save_flex_checkpoint: + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + dist.save_state_dict( + {**model_sharded_state_dict, **optimizer_sharded_state_dict}, + output_dir, + ) + if self.args.should_save: + if self.tokenizer is not None and self.args.save_tokenizer: + self.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) else: if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel: self._save_ckpt_func( @@ -2849,7 +2926,17 @@ def _save_checkpoint(self, model, metrics=None): if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer): self._offload_optimizer() - + else: + if self.args.save_flex_checkpoint: + dist.save_state_dict( + model_sharded_state_dict, + output_dir, + ) + if self.args.should_save: + if self.tokenizer is not None and self.args.save_tokenizer: + self.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) self.runtime_timer.stop() # Maybe delete some older checkpoints. 
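The flex_checkpoint branches above persist the model and optimizer as one merged sharded state dict through `dist.save_state_dict`, and restore it with `dist.load_state_dict`. As a rough, hedged sketch of that round trip outside the Trainer (simplified: it omits the `init_optimizer` call the resume path performs so that optimizer accumulators exist before loading, and the `ckpt_dir` / `aoa_config` names are illustrative):

    import paddle.distributed as dist

    def save_flex_checkpoint(model, optimizer, ckpt_dir):
        # Sharded views of parameters and optimizer states; every rank participates.
        model_ssd = model.sharded_state_dict()
        opt_ssd = optimizer.sharded_state_dict(model_ssd)
        dist.save_state_dict({**model_ssd, **opt_ssd}, ckpt_dir)

    def load_flex_checkpoint(model, optimizer, ckpt_dir, aoa_config=None):
        model_ssd = model.sharded_state_dict()
        opt_ssd = optimizer.sharded_state_dict(model_ssd)
        # aoa_config optionally remaps checkpoint keys onto the current model layout.
        dist.load_state_dict({**model_ssd, **opt_ssd}, ckpt_dir, aoa_config=aoa_config)
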
@@ -3064,6 +3151,7 @@ def _save( else: if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model: config_to_save = None + self.sharding_io.set_optimizer(self.optimizer) state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config( self.model, merge_tensor_parallel=merge_tensor_parallel ) @@ -3093,6 +3181,24 @@ def _save( with open(path, "w") as f: json.dump(model_meta, f) + def _load_scheduler(self, checkpoint): + if checkpoint is None: + self.runtime_timer.stop() + return + + if not self.args.ignore_load_lr_and_optim: + if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + self.lr_scheduler.set_state_dict( + paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME))) + ) + else: + raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}") + + if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)): + self.scaler.load_state_dict( + paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True) + ) + def _load_optimizer_and_scheduler(self, checkpoint): """If optimizer and scheduler states exist, load them.""" self.runtime_timer.start("checkpoint loading time") @@ -3134,6 +3240,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): and "split_param" in split_parallel_config(self.args.sharding_parallel_config) ): model = self.model_wrapped + opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer( model=model, optimizer=self.optimizer, @@ -3165,18 +3272,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.") - if not self.args.ignore_load_lr_and_optim: - if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)): - self.lr_scheduler.set_state_dict( - paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME))) - ) - else: - raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}") - - if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)): - self.scaler.load_state_dict( - paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True) - ) + self._load_scheduler(checkpoint) if self.args.offload_optim: logger.info("Offloading optimizer state...") diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index d8d88d1cd4ad..0b9fa9ea5c16 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -37,6 +37,10 @@ import paddle import paddle.distributed as dist from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + DygraphShardingOptimizerV2, +) from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.io import IterableDataset from paddle.optimizer.lr import LambdaDecay @@ -1357,3 +1361,87 @@ def set_comm_config(configs, attr, dict_obj): set_comm_config("moe_sharding_configs", "check_nccl_config", nccl_config.get("moe_sharding_check", None)) set_comm_config("default_comm_group_configs", "nccl_config", nccl_config.get("default", None)) return strategy + + +def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata): + """ + Initialize the optimizer's states according to its type. 
+ + For DygraphShardingOptimizer (V1), initializes accumulators for local parameters. + For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters. + For other cases, initializes accumulators for all parameters. + + Args: + optimizer: The optimizer instance to be initialized. + """ + optimizer_state_names = [".moment1_0", ".moment2_0", ".beta1_pow_acc_0", ".beta2_pow_acc_0", ".w_0"] + inner_opt = getattr(optimizer, "_inner_opt", None) + static_to_struct_mapping = {} + model_sharded_state_dict = dict(sorted(model_sharded_state_dict.items())) + for k, v in model_sharded_state_dict.items(): + if v.local_tensor.name not in static_to_struct_mapping: + static_to_struct_mapping[v.local_tensor.name] = k + + if isinstance(inner_opt, DygraphShardingOptimizer): + local_params = optimizer._rank2params[optimizer._sharding_rank] + param_list = [] + for param in local_params: + param_name = param.name + struct_name = static_to_struct_mapping[param_name] + if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names): + continue + param_list.append(param) + optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list) + return + + elif isinstance(inner_opt, DygraphShardingOptimizerV2): + + def init_param_optimizer_states(param_iter): + master_weights = {} + state_dict = {} + moments = ("moment1_0", "moment2_0") + betas = ("beta1_pow_acc_0", "beta2_pow_acc_0") + for static_name, shape, no_need_master_weights in param_iter: + if not no_need_master_weights: + master_weights[static_name] = paddle.zeros(shape, dtype="float32") + prefix = f"{static_name}_fp32_master_0_" + else: + prefix = f"{static_name}_" + + for moment in moments: + key = f"{prefix}{moment}" + state_dict[key] = paddle.zeros(shape, dtype="float32") + for beta in betas: + key = f"{prefix}{beta}" + state_dict[key] = paddle.zeros((1,), dtype="float32") + return master_weights, state_dict + + def buffer_params(): + for buffer in optimizer._comm_buffer_list: + for param_name, grad_view in buffer._sharding_param_grad_view.items(): + struct_name = static_to_struct_mapping[param_name] + if not any( + struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names + ): + continue + param_begin = grad_view._param_begin + param_end = grad_view._param_end + shape = (param_end - param_begin,) + no_need_master_weights = grad_view._param.dtype == paddle.float32 + + if shape[0] > 0: + yield param_name, shape, no_need_master_weights + + master_weights, state_dict = init_param_optimizer_states(buffer_params()) + state_dict["master_weights"] = master_weights + state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06} + optimizer.set_state_dict(state_dict) + return + param_list = [] + for param in optimizer._parameter_list: + param_name = param.name + struct_name = static_to_struct_mapping[param_name] + if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names): + continue + param_list.append(param) + optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 88da91adced6..ed4af3953e93 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -407,6 +407,12 @@ class TrainingArguments: Whether to release gradients during training. Default is `False`. 
ckpt_quant_stage (`str`, *optional*): Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0). + save_checkpoint_mode (`str`, *optional*): + Specifies the method for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured. + load_checkpoint_mode (`str`, *optional*): + Specifies the method for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured. + aoa_config (`Optional[dict[str, list[str]]]`, *optional*): + The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None. """ output_dir: str = field( @@ -941,6 +947,29 @@ class TrainingArguments: default=False, metadata={"help": "Whether to use async_save instead of paddle.save."}, ) + save_checkpoint_mode: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Specifies the method used to save checkpoints. " + "Available options: 'sharding_io', 'unified_checkpoint', " + "'flex_checkpoint', 'safetensor'." + "This setting is ignored if the corresponding switch is configured." + ) + }, + ) + + load_checkpoint_mode: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Specifies the method used to load checkpoints. " + "Available options: 'sharding_io', 'unified_checkpoint', " + "'flex_checkpoint', 'safetensor'." + "This setting is ignored if the corresponding switch is configured." + ) + }, + ) ordered_save_group_size: int = field( default=0, metadata={ @@ -1106,6 +1135,13 @@ class TrainingArguments: default=None, metadata={"help": "NCCL中通信组的细粒度控制的配置文件路径, 默认值为None, 代表不启用此项配置"} ) + aoa_config: Optional[dict[str, list[str]]] = field( + default=None, + metadata={ + "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None." 
+ }, + ) + def __post_init__(self): world_size = paddle.distributed.get_world_size() if in_auto_parallel_align_mode(): @@ -1205,7 +1241,8 @@ def __post_init__(self): raise ValueError("AdamW Mini currently doesn't support tensor parallelism.") self._post_init_parallel_degree() - + self._post_init_save_checkpoint_mode() + self._post_init_load_checkpoint_mode() if self.to_static: assert world_size == 1 or self.enable_auto_parallel, ( "It's not supported for training in static mode except the following cases : " @@ -1863,6 +1900,9 @@ def is_context_parallel_supported(): # DP use hybrid group strategy = fleet.DistributedStrategy() fleet.init(is_collective=True, strategy=strategy) + elif self.save_flex_checkpoint or self.load_flex_checkpoint: + strategy = fleet.DistributedStrategy() + fleet.init(is_collective=True, strategy=strategy) else: paddle.distributed.init_parallel_env() @@ -2129,6 +2169,65 @@ def _post_init_parallel_degree(self): if self.use_hybrid_parallel and self.enable_auto_parallel: self.use_hybrid_parallel = False + def _post_init_save_checkpoint_mode(self): + self.save_flex_checkpoint = False + + if not self.save_checkpoint_mode: + return + + # Ensure that only one checkpoint mode is set at a time + if self.unified_checkpoint or self.save_sharded_model: + return + + valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"] + assert ( + self.save_checkpoint_mode in valid_modes + ), f"Invalid save_checkpoint_mode: {self.save_checkpoint_mode}, Only these modes are allowed: {valid_modes}." + + if self.save_checkpoint_mode == "safetensor": + raise NotImplementedError("safetensor checkpoint saving is not implemented yet.") + elif self.save_checkpoint_mode == "unified_checkpoint": + assert ( + getattr(self, "load_checkpoint_mode", None) == "unified_checkpoint" + ), "When saving in unified_checkpoint mode, load_checkpoint_mode must also be 'unified_checkpoint'." + self.unified_checkpoint = True + elif self.save_checkpoint_mode == "sharding_io": + self.save_sharded_model = True + elif self.save_checkpoint_mode == "flex_checkpoint": + self.save_flex_checkpoint = True + else: + raise NotImplementedError(f"Checkpoint mode '{self.save_checkpoint_mode}' is not supported.") + + def _post_init_load_checkpoint_mode(self): + + self.load_flex_checkpoint = False + + if not self.load_checkpoint_mode: + return + + # Ensure that only one checkpoint mode is set at a time + if self.unified_checkpoint or self.load_sharded_model: + return + + valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"] + assert ( + self.load_checkpoint_mode in valid_modes + ), f"Invalid load_checkpoint_mode: {self.load_checkpoint_mode}, Only these modes are allowed: {valid_modes}." + + if self.load_checkpoint_mode == "safetensor": + raise NotImplementedError("safetensor checkpoint loading is not implemented yet.") + elif self.load_checkpoint_mode == "unified_checkpoint": + assert ( + getattr(self, "save_checkpoint_mode", None) == "unified_checkpoint" + ), "When loading in unified_checkpoint mode, save_checkpoint_mode must also be 'unified_checkpoint'." 
+ self.unified_checkpoint = True + elif self.load_checkpoint_mode == "sharding_io": + self.load_sharded_model = True + elif self.load_checkpoint_mode == "flex_checkpoint": + self.load_flex_checkpoint = True + else: + raise NotImplementedError(f"Checkpoint mode '{self.load_checkpoint_mode}' is not supported.") + def add_moe_comm_group(self): hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs hcg = fleet.get_hybrid_communicate_group() @@ -2457,6 +2556,8 @@ def should_save_model_state(self): return True elif self.enable_auto_parallel: return True + elif self.save_flex_checkpoint: + return False elif self.use_hybrid_parallel: # save on dataset rank 0 return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 84c9bcc7ff4e..18c48f5470a8 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -30,6 +30,9 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ( + build_sharded_state_dict, +) from paddlenlp.transformers.refined_recompute import ( RRColumnParallelLinear, @@ -1367,7 +1370,6 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: @classmethod def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): - from paddlenlp.transformers.conversion_utils import split_or_merge_func fn = split_or_merge_func( @@ -1995,6 +1997,14 @@ def forward(self, hidden_states, tensor_parallel_output=None): ) return logits + def sharded_state_dict( + self, + structured_name_prefix: str = "", + ): + axis = 0 if self.transpose_y else 1 + state_dict = self.state_dict(structured_name_prefix="") + return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix) + class LlamaForCausalLM(LlamaPretrainedModel): enable_to_static_method = True diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index b5756896c65a..b478b253835f 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -3167,6 +3167,19 @@ def state_dict(self, *args, **kwargs): return state_dict + def sharded_state_dict(self, *args, **kwargs): + sharded_state_dict = super().sharded_state_dict(*args, **kwargs) + if self._single_to_pp_mapping is None: + self._set_pipeline_name_mapping() + assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!" 
+ + for k in list(sharded_state_dict.keys()): + v = sharded_state_dict.pop(k) + v.key = self._pp_to_single_mapping[k] + sharded_state_dict[self._pp_to_single_mapping[k]] = v + + return sharded_state_dict + def set_state_dict(self, state_dict, *args, **kwargs): if self._single_to_pp_mapping is None: self._set_pipeline_name_mapping() From d0d67a8133609cfaca3e3fbef720196fb8532c61 Mon Sep 17 00:00:00 2001 From: zty-king <17786324919@163.com> Date: Sat, 20 Sep 2025 17:14:40 +0000 Subject: [PATCH 2/6] Consistently use save_checkpoint_format and load_checkpoint_format --- paddlenlp/trainer/trainer.py | 97 +++++++++++++++------- paddlenlp/trainer/training_args.py | 127 ++++++++++++----------------- 2 files changed, 118 insertions(+), 106 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index f5fbac5bf27f..87472ff518e5 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -381,7 +381,10 @@ def __init__( is_ema=self.args.sharded_model_from_ema, ) - if self.args.unified_checkpoint: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" + ): self.unified_checkpoint_handler = UnifiedCheckpointHandler(self.args) if self.sharding is not None and self.optimizer is not None: @@ -435,8 +438,9 @@ def _save_ckpt_func(state_dict, path, signal_path=None): not self.args.ignore_save_lr_and_optim ), "ignore_save_lr_and_optim should be False when using zero cost checkpoint" assert self.args.use_hybrid_parallel, "use_hybrid_parallel must be True when using zero cost checkpoint" - assert ( - not self.args.unified_checkpoint + assert not ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" ), "use_unified_checkpoint should be False when using zero cost checkpoint" assert not strtobool( os.getenv("FLAG_LLM_PDC", "False") @@ -474,7 +478,10 @@ def _save_ckpt_func(state_dict, path, signal_path=None): or isinstance(self.model, LoKrModel) or isinstance(self.model, ReFTModel) ): - if self.args.unified_checkpoint and "skip_save_model_weight" in self.args.unified_checkpoint_config: + if ( + self.args.load_checkpoint_format == "unified_checkpoint" + and "skip_save_model_weight" in self.args.unified_checkpoint_config + ): self.args.unified_checkpoint_config.remove("skip_save_model_weight") logger.warning( "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config." 
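With the rename in PATCH 2/6, the on-disk checkpoint format is chosen through the `save_checkpoint_format` / `load_checkpoint_format` training arguments, while the older boolean switches such as `unified_checkpoint` and `save_sharded_model` still map onto these formats when the new arguments are left unset. A hedged usage sketch (the output directory and the choice of `flex_checkpoint` for both directions are illustrative, not mandated by the patch):

    from paddlenlp.trainer import TrainingArguments

    # Write new checkpoints in the FlexCheckpoint layout and resume from the
    # same layout; leaving both arguments unset falls back to the legacy flags.
    args = TrainingArguments(
        output_dir="./checkpoints",
        save_checkpoint_format="flex_checkpoint",
        load_checkpoint_format="flex_checkpoint",
    )

When the arguments are parsed from the command line, the same fields are typically exposed as `--save_checkpoint_format flex_checkpoint --load_checkpoint_format flex_checkpoint`.
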
@@ -658,14 +665,17 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): # Load potential model checkpoint if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: - uc_async_save = self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config + uc_async_save = ( + self.args.load_checkpoint_format == "unified_checkpoint" + and "async_save" in self.args.unified_checkpoint_config + ) resume_from_checkpoint = get_last_checkpoint( self.args.output_dir, signal_folder=self.args.output_signal_dir, uc_async_save=uc_async_save ) if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})") - if self.args.unified_checkpoint: + if self.args.load_checkpoint_format == "unified_checkpoint": if resume_from_checkpoint is not None: use_unified_checkpoint = False if self.is_unified_checkpoint(resume_from_checkpoint): @@ -938,13 +948,18 @@ def train( self._memory_tracker.start() if not self.args.enable_auto_parallel: - if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint: + if ( + not self.args.should_load_sharding_stage1_model + and not self.args.load_checkpoint_format == "flex_checkpoint" + ): self._load_from_checkpoint(resume_from_checkpoint) if self.args.should_load_sharding_stage1_model: model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint) - elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint: + elif self.args.should_save_sharding_stage1_model and not ( + self.args.load_checkpoint_format == "flex_checkpoint" + ): # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model. # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks. model = self._wrap_model(self.model_wrapped) @@ -958,7 +973,7 @@ def train( if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self._load_optimizer_and_scheduler(resume_from_checkpoint) - elif self.args.load_flex_checkpoint: + elif self.args.load_checkpoint_format == "flex_checkpoint": model = self._wrap_model(self.model_wrapped) if model is not self.model: self.model_wrapped = model @@ -1478,7 +1493,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): logger.info("\nTraining completed. \n") # unlink shared_memory if used. 
- if self.args.unified_checkpoint: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" + ): self.unified_checkpoint_handler.unlink_shared_memory() if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: @@ -1491,7 +1509,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): self._load_best_model_from_peft_checkpoint() else: - if self.args.unified_checkpoint: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" + ): self.unified_checkpoint_handler.load_unified_checkpoint( self.model, self.state.best_model_checkpoint, @@ -1541,7 +1562,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): return TrainOutput(self.state.global_step, train_loss, metrics) def _load_best_model_from_peft_checkpoint(self): - if self.args.unified_checkpoint: + if self.args.load_checkpoint_format == "unified_checkpoint": self.unified_checkpoint_handler.load_unified_checkpoint( self.model, self.state.best_model_checkpoint, @@ -2088,7 +2109,7 @@ def _load_rng_state(self, checkpoint): checkpoint_rng_state = paddle.load(rng_file, return_numpy=True) if checkpoint_rng_state.get("world_size", None) != self.args.world_size: - logger.warn("Cannot load rng states when changing world size of training job.") + logger.warning("Cannot load rng states when changing world size of training job.") return random.setstate(checkpoint_rng_state["python"]) @@ -2336,7 +2357,8 @@ def get_expected_keys(inputs, keys): if ( hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap - and self.args.unified_checkpoint + and self.args.load_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" and "split_param" in split_parallel_config(self.args.sharding_parallel_config) ): model.register_sharding_comm_overlap_hook(self.optimizer) @@ -2671,7 +2693,10 @@ def save_model( if self.args.should_save_model_state: self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel) else: - if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + and "async_save" in self.args.unified_checkpoint_config + ): os.makedirs(signal_dir, exist_ok=True) if self.is_in_train: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 @@ -2687,7 +2712,7 @@ def save_model( # For ckpt integrity paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done")) if ( - self.args.unified_checkpoint + self.args.load_checkpoint_format == "unified_checkpoint" and "async_save" in self.args.unified_checkpoint_config and not self.is_in_train ): @@ -2762,7 +2787,7 @@ def _save_checkpoint(self, model, metrics=None): else: self.save_model(output_dir) - if self.args.save_flex_checkpoint: + if self.args.save_checkpoint_format == "flex_checkpoint": model_sharded_state_dict = self.model.sharded_state_dict() os.makedirs(output_dir, exist_ok=True) @@ -2810,14 +2835,16 @@ def _save_checkpoint(self, model, metrics=None): optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") - if self.args.unified_checkpoint and (self.args.offload_optim or 
self.args.tensorwise_offload_optimizer): + if self.args.save_checkpoint_format == "unified_checkpoint" and ( + self.args.offload_optim or self.args.tensorwise_offload_optimizer + ): self._reload_optimizer() if self.args.use_hybrid_parallel: if self.dp_group.rank <= 0 or self.args.use_expert_parallel: os.makedirs(output_dir, exist_ok=True) logger.info("Saving optimizer files.") - if self.args.unified_checkpoint: + if self.args.save_checkpoint_format == "unified_checkpoint": self.unified_checkpoint_handler.save_unified_optimizer( self.model, self.optimizer, @@ -2825,7 +2852,7 @@ def _save_checkpoint(self, model, metrics=None): signal_dir, ) else: - if self.args.save_flex_checkpoint: + if self.args.save_checkpoint_format == "flex_checkpoint": optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) dist.save_state_dict( {**model_sharded_state_dict, **optimizer_sharded_state_dict}, @@ -2858,7 +2885,10 @@ def _save_checkpoint(self, model, metrics=None): self._save_ckpt_func(state_dict, save_path, saved_signal_path) else: - if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + and "async_save" in self.args.unified_checkpoint_config + ): global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 os.makedirs(signal_dir, exist_ok=True) paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}")) @@ -2871,18 +2901,18 @@ def _save_checkpoint(self, model, metrics=None): if ( self.args.should_save or self.args.use_expert_parallel - or (self.args.data_parallel_degree > 1 and self.args.save_flex_checkpoint) + or (self.args.data_parallel_degree > 1 and self.args.save_checkpoint_format == "flex_checkpoint") ): if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") - if self.args.unified_checkpoint: + if self.args.save_checkpoint_format == "unified_checkpoint": self.unified_checkpoint_handler.save_unified_optimizer( self.model, self.optimizer, output_dir, signal_dir, ) - elif self.args.save_flex_checkpoint: + elif self.args.save_checkpoint_format == "flex_checkpoint": optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) dist.save_state_dict( {**model_sharded_state_dict, **optimizer_sharded_state_dict}, @@ -2913,7 +2943,7 @@ def _save_checkpoint(self, model, metrics=None): if self.do_grad_scaling: paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) else: - if self.args.unified_checkpoint and not self.args.use_hybrid_parallel: + if self.args.save_checkpoint_format == "unified_checkpoint" and not self.args.use_hybrid_parallel: if "async_save" in self.args.unified_checkpoint_config: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 os.makedirs(signal_dir, exist_ok=True) @@ -2924,10 +2954,12 @@ def _save_checkpoint(self, model, metrics=None): ): paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) - if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer): + if self.args.save_checkpoint_format == "unified_checkpoint" and ( + self.args.offload_optim or self.args.tensorwise_offload_optimizer + ): self._offload_optimizer() else: - if self.args.save_flex_checkpoint: + if self.args.save_checkpoint_format == "flex_checkpoint": dist.save_state_dict( model_sharded_state_dict, output_dir, @@ -3039,7 +3071,10 @@ 
def _save( # signal_dir is used for asynchronous saving situations. signal_dir = self.args.output_signal_dir - if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + if ( + self.args.save_checkpoint_format == "unified_checkpoint" + and "async_save" in self.args.unified_checkpoint_config + ): if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]: signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1]) os.makedirs(signal_dir, exist_ok=True) @@ -3051,7 +3086,7 @@ def _save( if ( strtobool(os.getenv("FLAG_LLM_PDC", "False")) and paddle.distributed.get_rank() == 0 - and self.args.unified_checkpoint + and self.args.save_checkpoint_format == "unified_checkpoint" and "async_save" in self.args.unified_checkpoint_config ): world_size = paddle.distributed.get_world_size() @@ -3074,7 +3109,7 @@ def _save( # Good practice: save your training arguments together with the trained model paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) - if self.args.unified_checkpoint: + if self.args.save_checkpoint_format == "unified_checkpoint": unified_checkpoint_config_backup = self.args.unified_checkpoint_config # backup and remove unified_checkpoint_config for not trine stage if not self.is_in_train: @@ -3218,7 +3253,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): ) else: use_unified_checkpoint = False - if self.args.unified_checkpoint: + if self.args.load_checkpoint_format == "unified_checkpoint": if self.is_unified_checkpoint(checkpoint): use_unified_checkpoint = True else: diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index ed4af3953e93..092b7e08e355 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -407,10 +407,10 @@ class TrainingArguments: Whether to release gradients during training. Default is `False`. ckpt_quant_stage (`str`, *optional*): Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0). - save_checkpoint_mode (`str`, *optional*): - Specifies the method for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured. - load_checkpoint_mode (`str`, *optional*): - Specifies the method for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured. + save_checkpoint_format (`str`, *optional*): + Specifies the format for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured. + load_checkpoint_format (`str`, *optional*): + Specifies the format for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured. aoa_config (`Optional[dict[str, list[str]]]`, *optional*): The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None. 
""" @@ -947,11 +947,11 @@ class TrainingArguments: default=False, metadata={"help": "Whether to use async_save instead of paddle.save."}, ) - save_checkpoint_mode: Optional[str] = field( + save_checkpoint_format: Optional[str] = field( default=None, metadata={ "help": ( - "Specifies the method used to save checkpoints. " + "Specifies the format used to save checkpoints. " "Available options: 'sharding_io', 'unified_checkpoint', " "'flex_checkpoint', 'safetensor'." "This setting is ignored if the corresponding switch is configured." @@ -959,11 +959,11 @@ class TrainingArguments: }, ) - load_checkpoint_mode: Optional[str] = field( + load_checkpoint_format: Optional[str] = field( default=None, metadata={ "help": ( - "Specifies the method used to load checkpoints. " + "Specifies the format used to load checkpoints. " "Available options: 'sharding_io', 'unified_checkpoint', " "'flex_checkpoint', 'safetensor'." "This setting is ignored if the corresponding switch is configured." @@ -1241,8 +1241,8 @@ def __post_init__(self): raise ValueError("AdamW Mini currently doesn't support tensor parallelism.") self._post_init_parallel_degree() - self._post_init_save_checkpoint_mode() - self._post_init_load_checkpoint_mode() + self._post_init_save_checkpoint_format() + self._post_init_load_checkpoint_format() if self.to_static: assert world_size == 1 or self.enable_auto_parallel, ( "It's not supported for training in static mode except the following cases : " @@ -1896,27 +1896,31 @@ def is_context_parallel_supported(): else: if world_size > 1: if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(): - if self.unified_checkpoint: + if self.save_checkpoint_format in [ + "unified_checkpoint", + "flex_checkpoint", + ] or self.load_checkpoint_format in ["unified_checkpoint", "flex_checkpoint"]: # DP use hybrid group strategy = fleet.DistributedStrategy() fleet.init(is_collective=True, strategy=strategy) - elif self.save_flex_checkpoint or self.load_flex_checkpoint: - strategy = fleet.DistributedStrategy() - fleet.init(is_collective=True, strategy=strategy) else: paddle.distributed.init_parallel_env() if ( - self.unified_checkpoint + ( + self.save_checkpoint_format == "unified_checkpoint" + or self.load_checkpoint_format == "unified_checkpoint" + ) and self.sharding_parallel_degree > 0 and ShardingOption.FULL_SHARD in self.sharding ): logger.warning( - "Unified checkpoint currently do not support sharding stage3, set `unified_checkpoint` to False." + "Unified checkpoint currently do not support sharding stage3, disabling unified_checkpoint format." 
) - self.unified_checkpoint = False + self.save_checkpoint_format = None + self.load_checkpoint_format = None - if self.unified_checkpoint: + if self.save_checkpoint_format == "unified_checkpoint" or self.load_checkpoint_format == "unified_checkpoint": unified_checkpoint_config = set(self.unified_checkpoint_config.split(" ")) if sys.platform.startswith("win") and "async_save" in self.unified_checkpoint_config: raise ValueError("Currently do not support asynchronous saving for Windows system!") @@ -2169,64 +2173,35 @@ def _post_init_parallel_degree(self): if self.use_hybrid_parallel and self.enable_auto_parallel: self.use_hybrid_parallel = False - def _post_init_save_checkpoint_mode(self): - self.save_flex_checkpoint = False - - if not self.save_checkpoint_mode: - return - - # Ensure that only one checkpoint mode is set at a time - if self.unified_checkpoint or self.save_sharded_model: - return - - valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"] - assert ( - self.save_checkpoint_mode in valid_modes - ), f"Invalid save_checkpoint_mode: {self.save_checkpoint_mode}, Only these modes are allowed: {valid_modes}." - - if self.save_checkpoint_mode == "safetensor": - raise NotImplementedError("safetensor checkpoint saving is not implemented yet.") - elif self.save_checkpoint_mode == "unified_checkpoint": + def _post_init_save_checkpoint_format(self): + if self.save_checkpoint_format: + valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"] assert ( - getattr(self, "load_checkpoint_mode", None) == "unified_checkpoint" - ), "When saving in unified_checkpoint mode, load_checkpoint_mode must also be 'unified_checkpoint'." - self.unified_checkpoint = True - elif self.save_checkpoint_mode == "sharding_io": - self.save_sharded_model = True - elif self.save_checkpoint_mode == "flex_checkpoint": - self.save_flex_checkpoint = True - else: - raise NotImplementedError(f"Checkpoint mode '{self.save_checkpoint_mode}' is not supported.") + self.save_checkpoint_format in valid_modes + ), f"Invalid save_checkpoint_format: {self.save_checkpoint_format}, Only these formats are allowed: {valid_modes}." - def _post_init_load_checkpoint_mode(self): - - self.load_flex_checkpoint = False - - if not self.load_checkpoint_mode: - return - - # Ensure that only one checkpoint mode is set at a time - if self.unified_checkpoint or self.load_sharded_model: - return - - valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"] - assert ( - self.load_checkpoint_mode in valid_modes - ), f"Invalid load_checkpoint_mode: {self.load_checkpoint_mode}, Only these modes are allowed: {valid_modes}." - - if self.load_checkpoint_mode == "safetensor": - raise NotImplementedError("safetensor checkpoint loading is not implemented yet.") - elif self.load_checkpoint_mode == "unified_checkpoint": + if self.save_checkpoint_format == "safetensor": + raise NotImplementedError("safetensor checkpoint saving is not implemented yet.") + else: + if self.unified_checkpoint: + self.save_checkpoint_format = "unified_checkpoint" + elif self.save_sharded_model: + self.save_checkpoint_format = "sharding_io" + + def _post_init_load_checkpoint_format(self): + if self.load_checkpoint_format: + valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"] assert ( - getattr(self, "save_checkpoint_mode", None) == "unified_checkpoint" - ), "When loading in unified_checkpoint mode, save_checkpoint_mode must also be 'unified_checkpoint'." 
- self.unified_checkpoint = True - elif self.load_checkpoint_mode == "sharding_io": - self.load_sharded_model = True - elif self.load_checkpoint_mode == "flex_checkpoint": - self.load_flex_checkpoint = True + self.load_checkpoint_format in valid_modes + ), f"Invalid load_checkpoint_format: {self.load_checkpoint_format}, Only these formats are allowed: {valid_modes}." + + if self.load_checkpoint_format == "safetensor": + raise NotImplementedError("safetensor checkpoint loading is not implemented yet.") else: - raise NotImplementedError(f"Checkpoint mode '{self.load_checkpoint_mode}' is not supported.") + if self.unified_checkpoint: + self.load_checkpoint_format = "unified_checkpoint" + elif self.load_sharded_model: + self.load_checkpoint_format = "sharding_io" def add_moe_comm_group(self): hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs @@ -2556,7 +2531,7 @@ def should_save_model_state(self): return True elif self.enable_auto_parallel: return True - elif self.save_flex_checkpoint: + elif self.save_checkpoint_format == "flex_checkpoint": return False elif self.use_hybrid_parallel: # save on dataset rank 0 @@ -2576,14 +2551,16 @@ def should_save_sharding_stage1_model(self): if self.enable_auto_parallel: return False return ( - ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.save_sharded_model + ShardingOption.SHARD_OP in self.sharding + and self.sharding_parallel_degree > 1 + and self.save_checkpoint_format == "sharding_io" ) @property def should_load_sharding_stage1_model(self): if self.enable_auto_parallel: return False - return self.load_sharded_model + return self.load_checkpoint_format == "sharding_io" @property def should_load_dataset(self): From f11a69ec74ed49a28e0bc14e5670d5cc8a96c86c Mon Sep 17 00:00:00 2001 From: zty-king <17786324919@163.com> Date: Sun, 21 Sep 2025 08:32:23 +0000 Subject: [PATCH 3/6] fix the ckpt_format adapt --- paddlenlp/trainer/trainer.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 87472ff518e5..051899ce6760 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -479,7 +479,7 @@ def _save_ckpt_func(state_dict, path, signal_path=None): or isinstance(self.model, ReFTModel) ): if ( - self.args.load_checkpoint_format == "unified_checkpoint" + self.args.save_checkpoint_format == "unified_checkpoint" and "skip_save_model_weight" in self.args.unified_checkpoint_config ): self.args.unified_checkpoint_config.remove("skip_save_model_weight") @@ -1509,10 +1509,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): self._load_best_model_from_peft_checkpoint() else: - if ( - self.args.save_checkpoint_format == "unified_checkpoint" - or self.args.load_checkpoint_format == "unified_checkpoint" - ): + if self.args.load_checkpoint_format == "unified_checkpoint": self.unified_checkpoint_handler.load_unified_checkpoint( self.model, self.state.best_model_checkpoint, @@ -2357,8 +2354,10 @@ def get_expected_keys(inputs, keys): if ( hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap - and self.args.load_checkpoint_format == "unified_checkpoint" - or self.args.load_checkpoint_format == "unified_checkpoint" + and ( + self.args.save_checkpoint_format == "unified_checkpoint" + or self.args.load_checkpoint_format == "unified_checkpoint" + ) and "split_param" in 
split_parallel_config(self.args.sharding_parallel_config) ): model.register_sharding_comm_overlap_hook(self.optimizer) @@ -2712,7 +2711,7 @@ def save_model( # For ckpt integrity paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done")) if ( - self.args.load_checkpoint_format == "unified_checkpoint" + self.args.save_checkpoint_format == "unified_checkpoint" and "async_save" in self.args.unified_checkpoint_config and not self.is_in_train ): From 8afc8f098d966100b419809c63bcf37ae8f87d83 Mon Sep 17 00:00:00 2001 From: zty-king <17786324919@163.com> Date: Mon, 22 Sep 2025 07:29:32 +0000 Subject: [PATCH 4/6] delete the duplicate code --- paddlenlp/trainer/trainer.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 051899ce6760..e90634f91b8d 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -822,10 +822,6 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None): if resume_from_checkpoint is not None: path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema") - if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None: - success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint) - else: - success, err_msg = True, None if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None: success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint) else: @@ -836,11 +832,6 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None): self.zcc_manager.set_ema_state_dict(path) else: logger.info(f"ZCC EMA does not load {path} because {err_msg}") - if success: - logger.info(f"ZCC EMA load from {path}") - self.zcc_manager.set_ema_state_dict(path) - else: - logger.info(f"ZCC EMA does not load {path} because {err_msg}") else: logger.info(f"ZCC EMA state dict not found, in: {path}") From f7ff471ba1d17c75ee8d004f7621069fdda1bbe0 Mon Sep 17 00:00:00 2001 From: zty-king <17786324919@163.com> Date: Mon, 22 Sep 2025 14:02:42 +0000 Subject: [PATCH 5/6] delete safetensor --- paddlenlp/trainer/training_args.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 092b7e08e355..35d387be80a2 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -408,9 +408,9 @@ class TrainingArguments: ckpt_quant_stage (`str`, *optional*): Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0). save_checkpoint_format (`str`, *optional*): - Specifies the format for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured. + Specifies the format for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured. load_checkpoint_format (`str`, *optional*): - Specifies the format for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured. + Specifies the format for loading checkpoints. 
Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint'. (default: None). This setting is ignored if the corresponding switch is configured. aoa_config (`Optional[dict[str, list[str]]]`, *optional*): The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None. """ @@ -953,7 +953,7 @@ class TrainingArguments: "help": ( "Specifies the format used to save checkpoints. " "Available options: 'sharding_io', 'unified_checkpoint', " - "'flex_checkpoint', 'safetensor'." + "'flex_checkpoint'." "This setting is ignored if the corresponding switch is configured." ) }, @@ -965,7 +965,7 @@ class TrainingArguments: "help": ( "Specifies the format used to load checkpoints. " "Available options: 'sharding_io', 'unified_checkpoint', " - "'flex_checkpoint', 'safetensor'." + "'flex_checkpoint'." "This setting is ignored if the corresponding switch is configured." ) }, @@ -2175,13 +2175,10 @@ def _post_init_parallel_degree(self): def _post_init_save_checkpoint_format(self): if self.save_checkpoint_format: - valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"] + valid_modes = ["unified_checkpoint", "sharding_io", "flex_checkpoint"] assert ( self.save_checkpoint_format in valid_modes ), f"Invalid save_checkpoint_format: {self.save_checkpoint_format}, Only these formats are allowed: {valid_modes}." - - if self.save_checkpoint_format == "safetensor": - raise NotImplementedError("safetensor checkpoint saving is not implemented yet.") else: if self.unified_checkpoint: self.save_checkpoint_format = "unified_checkpoint" @@ -2190,13 +2187,10 @@ def _post_init_save_checkpoint_format(self): def _post_init_load_checkpoint_format(self): if self.load_checkpoint_format: - valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"] + valid_modes = ["unified_checkpoint", "sharding_io", "flex_checkpoint"] assert ( self.load_checkpoint_format in valid_modes ), f"Invalid load_checkpoint_format: {self.load_checkpoint_format}, Only these formats are allowed: {valid_modes}." 
- - if self.load_checkpoint_format == "safetensor": - raise NotImplementedError("safetensor checkpoint loading is not implemented yet.") else: if self.unified_checkpoint: self.load_checkpoint_format = "unified_checkpoint" From f4b75264f54eb07b4f456b7ac6bb1115c14b8ceb Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 24 Sep 2025 20:09:59 +0800 Subject: [PATCH 6/6] Refactor checkpoint saving: separate model, optimizer, and master weights into different directories --- paddlenlp/trainer/trainer.py | 222 +++++++++++++++++++++++++++++++---- 1 file changed, 196 insertions(+), 26 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index e90634f91b8d..08a1ef0f8d9a 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -34,6 +34,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import deepcopy import numpy as np import paddle import paddle.amp.auto_cast as autocast @@ -49,6 +50,9 @@ except: core = None from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizerV2, +) from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( HybridParallelOptimizer, ) @@ -97,6 +101,8 @@ except: pass +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ShardedWeight + from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance from ..transformers.model_utils import ( PretrainedModel, @@ -221,6 +227,11 @@ def in_auto_parallel_align_mode(): return False +MODEL_STATE_DIC = "model_state" +OPTIMIZER_STATE_DIC = "optimizer_state" +MASTER_WEIGHT_DIC = "master_weight" + + __all__ = ["Trainer"] @@ -837,6 +848,140 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None): logger.info("Create zero cost checkpoint manager done.") + def _load_flex_checkpoint(self, resume_from_checkpoint): + model_sharded_state_dict = self.model.sharded_state_dict() + master_weights_path = os.path.join(resume_from_checkpoint, MASTER_WEIGHT_DIC) + opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC) + model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC) + if not self.args.ignore_load_lr_and_optim: + state_dict_metadata = {} + metadata_paths = [ + os.path.join(model_states_path, "0.metadata"), + os.path.join(opt_states_path, "0.metadata"), + os.path.join(master_weights_path, "0.metadata"), + ] + + for metadata_file in metadata_paths: + if not os.path.exists(metadata_file): + raise FileNotFoundError(f"Metadata file not found: {metadata_file}") + metadata = paddle.load(metadata_file) + if hasattr(metadata, "state_dict_metadata"): + state_dict_metadata.update(metadata.state_dict_metadata) + else: + raise AttributeError( + f"Loaded metadata from {metadata_file} does not have 'state_dict_metadata' attribute" + ) + + init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata) + + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + for k, v in optimizer_sharded_state_dict.items(): + v.local_tensor._clear_to_zero_allocation() + + if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2): + color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list + for color, _comm_buffer_list in color_to_comm_buffer_list.items(): + for comm_buffer in _comm_buffer_list: + comm_buffer._clear_param_storage() + else: + state_dict = 
self.model.state_dict() + for k, v in state_dict.items(): + v._clear_to_zero_allocation() + + opt_states = {} + master_weights = {} + for k, v in optimizer_sharded_state_dict.items(): + if k.endswith(".w_0"): + master_weights[k] = v + else: + opt_states[k] = v + + for k, v in opt_states.items(): + new_v = ShardedWeight( + key=v.key, + local_tensor=paddle.zeros_like(v.local_tensor), + local_shape=deepcopy(v.local_shape), + global_shape=deepcopy(v.global_shape), + global_offset=deepcopy(v.global_offset), + is_flattened=v.is_flattened, + flattened_range=deepcopy(v.flattened_range), + ) + opt_states[k] = new_v + + dist.load_state_dict( + opt_states, + opt_states_path, + aoa_config=self.args.aoa_config, + ) + + optimizer_state_pin = {} + + for k, v in opt_states.items(): + tmp = v.local_tensor + optimizer_state_pin[k] = tmp.pin_memory() + tmp._clear_to_zero_allocation() + del tmp + + for k, v in master_weights.items(): + new_v = ShardedWeight( + key=v.key, + local_tensor=paddle.zeros_like(v.local_tensor), + local_shape=deepcopy(v.local_shape), + global_shape=deepcopy(v.global_shape), + global_offset=deepcopy(v.global_offset), + is_flattened=v.is_flattened, + flattened_range=deepcopy(v.flattened_range), + ) + master_weights[k] = new_v + + dist.load_state_dict( + master_weights, + master_weights_path, + aoa_config=self.args.aoa_config, + ) + + master_weights_pin = {} + + for k, v in master_weights.items(): + tmp = v.local_tensor + master_weights_pin[k] = tmp.pin_memory() + tmp._clear_to_zero_allocation() + del tmp + + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + + optimizer_sharded_state_dict_pin = {**master_weights_pin, **optimizer_state_pin} + + for k, v in optimizer_sharded_state_dict.items(): + source_tensor = optimizer_sharded_state_dict_pin[k] + target_tensor = paddle.zeros_like(v.local_tensor) + if source_tensor.place != target_tensor.place: + source_tensor = source_tensor.to(target_tensor.place) + paddle.assign(source_tensor, target_tensor) + target_tensor_pin = target_tensor.cpu() + del target_tensor + target_tensor_pin._share_buffer_to(v.local_tensor) + del source_tensor + + if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2): + color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list + for color, _comm_buffer_list in color_to_comm_buffer_list.items(): + for comm_buffer in _comm_buffer_list: + comm_buffer._reset_param_storage() + else: + state_dict = self.model.state_dict() + for k, v in state_dict.items(): + new_v = paddle.zeros_like(v) + new_v._share_buffer_to(v) + + self._load_scheduler(resume_from_checkpoint) + + dist.load_state_dict( + model_sharded_state_dict, + model_states_path, + aoa_config=self.args.aoa_config, + ) + def train( self, resume_from_checkpoint: Optional[Union[str, bool]] = None, @@ -970,28 +1115,8 @@ def train( self.model_wrapped = model if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) - if resume_from_checkpoint is not None: - if not self.args.ignore_load_lr_and_optim: - model_sharded_state_dict = self.model.sharded_state_dict() - accessible_files = os.listdir(resume_from_checkpoint) - metadata_files = [file for file in accessible_files if file.endswith(".metadata")] - assert len(metadata_files) == 1, "Only support one metadata file now." 
- metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0])) - state_dict_metadata = metadata.state_dict_metadata - init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata) - optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) - sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict} - dist.load_state_dict( - sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config - ) - self._load_scheduler(resume_from_checkpoint) - else: - model_sharded_state_dict = self.model.sharded_state_dict() - sharded_state_dict = model_sharded_state_dict - dist.load_state_dict( - sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config - ) + self._load_flex_checkpoint(resume_from_checkpoint) else: model = self._wrap_model(self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not @@ -2779,7 +2904,12 @@ def _save_checkpoint(self, model, metrics=None): if self.args.save_checkpoint_format == "flex_checkpoint": model_sharded_state_dict = self.model.sharded_state_dict() - os.makedirs(output_dir, exist_ok=True) + model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC) + os.makedirs(model_state_dict_path, exist_ok=True) + dist.save_state_dict( + model_sharded_state_dict, + model_state_dict_path, + ) # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: @@ -2843,10 +2973,26 @@ def _save_checkpoint(self, model, metrics=None): ) else: if self.args.save_checkpoint_format == "flex_checkpoint": + optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC) + optimizer_states = {} + master_weights = {} + + model_sharded_state_dict = self.model.sharded_state_dict() optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + for k, v in optimizer_sharded_state_dict.items(): + if k.endswith(".w_0"): + master_weights[k] = v + else: + optimizer_states[k] = v + dist.save_state_dict( - {**model_sharded_state_dict, **optimizer_sharded_state_dict}, - output_dir, + optimizer_states, + optimizer_state_dict_path, + ) + master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC) + dist.save_state_dict( + master_weights, + master_weights_path, ) if self.args.should_save: if self.tokenizer is not None and self.args.save_tokenizer: @@ -2904,10 +3050,34 @@ def _save_checkpoint(self, model, metrics=None): ) elif self.args.save_checkpoint_format == "flex_checkpoint": optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + model_sharded_state_dict = self.model.sharded_state_dict() + optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict) + model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC) + os.makedirs(model_state_dict_path, exist_ok=True) dist.save_state_dict( - {**model_sharded_state_dict, **optimizer_sharded_state_dict}, - output_dir, + model_sharded_state_dict, + model_state_dict_path, ) + if not self.args.ignore_save_lr_and_optim: + optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC) + optimizer_states = {} + master_weights = {} + for k, v in optimizer_sharded_state_dict.items(): + if k.endswith(".w_0"): + master_weights[k] = v + else: + optimizer_states[k] = v + + dist.save_state_dict( + optimizer_states, + optimizer_state_dict_path, + ) + + master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC) 
+ dist.save_state_dict( + master_weights, + master_weights_path, + ) if self.args.should_save: if self.tokenizer is not None and self.args.save_tokenizer: self.tokenizer.save_pretrained(output_dir)
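
A minimal configuration sketch for the flex_checkpoint flow introduced in these patches, assuming a single-node run; the output path and checkpoint step below are illustrative assumptions, while the argument names (save_checkpoint_format, load_checkpoint_format, resume_from_checkpoint) and the model_state / optimizer_state / master_weight sub-directories are taken from the diffs themselves:

    from paddlenlp.trainer import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints",                 # illustrative path, not from the patches
        save_checkpoint_format="flex_checkpoint",   # save model, optimizer and master weights into separate dirs
        load_checkpoint_format="flex_checkpoint",   # resume through Trainer._load_flex_checkpoint
    )

    # Expected layout of a checkpoint written with this configuration:
    #   checkpoint-<step>/
    #       model_state/       0.metadata plus sharded model weights
    #       optimizer_state/   0.metadata plus sharded optimizer states
    #       master_weight/     0.metadata plus sharded master weights (keys ending in ".w_0")
    #
    # Resuming then passes the checkpoint directory as usual, e.g.
    #   trainer.train(resume_from_checkpoint="./checkpoints/checkpoint-100")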