From 8d6e8e71f7cb5cc4435645a737517db36bbe86de Mon Sep 17 00:00:00 2001
From: zty-king <17786324919@163.com>
Date: Sat, 27 Sep 2025 11:25:41 +0000
Subject: [PATCH] fix_the_optimizer_init

---
 paddlenlp/trainer/trainer.py       |  7 +---
 paddlenlp/trainer/trainer_utils.py | 64 ++++++++------------------------
 2 files changed, 17 insertions(+), 54 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index a54fdb8d9c06..aa81f8340f43 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -979,12 +979,7 @@ def train(
         if resume_from_checkpoint is not None:
             if not self.args.ignore_load_lr_and_optim:
                 model_sharded_state_dict = self.model.sharded_state_dict()
-                accessible_files = os.listdir(resume_from_checkpoint)
-                metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                assert len(metadata_files) == 1, "Only support one metadata file now."
-                metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                state_dict_metadata = metadata.state_dict_metadata
-                init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                init_optimizer(self.optimizer)
                 optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
                 sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
                 dist.load_state_dict(
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index 0b9fa9ea5c16..ad5faab1bf77 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -1363,35 +1363,19 @@ def set_comm_config(configs, attr, dict_obj):
     return strategy
 
 
-def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
+def init_optimizer(optimizer):
     """
     Initialize the optimizer's states according to its type.
-
     For DygraphShardingOptimizer (V1), initializes accumulators for local parameters.
     For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters.
     For other cases, initializes accumulators for all parameters.
-
     Args:
         optimizer: The optimizer instance to be initialized.
     """
-    optimizer_state_names = [".moment1_0", ".moment2_0", ".beta1_pow_acc_0", ".beta2_pow_acc_0", ".w_0"]
     inner_opt = getattr(optimizer, "_inner_opt", None)
-    static_to_struct_mapping = {}
-    model_sharded_state_dict = dict(sorted(model_sharded_state_dict.items()))
-    for k, v in model_sharded_state_dict.items():
-        if v.local_tensor.name not in static_to_struct_mapping:
-            static_to_struct_mapping[v.local_tensor.name] = k
-
     if isinstance(inner_opt, DygraphShardingOptimizer):
         local_params = optimizer._rank2params[optimizer._sharding_rank]
-        param_list = []
-        for param in local_params:
-            param_name = param.name
-            struct_name = static_to_struct_mapping[param_name]
-            if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
-                continue
-            param_list.append(param)
-        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
+        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
         return
 
     elif isinstance(inner_opt, DygraphShardingOptimizerV2):
@@ -1399,49 +1383,33 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
         def init_param_optimizer_states(param_iter):
             master_weights = {}
             state_dict = {}
-            moments = ("moment1_0", "moment2_0")
-            betas = ("beta1_pow_acc_0", "beta2_pow_acc_0")
-            for static_name, shape, no_need_master_weights in param_iter:
-                if not no_need_master_weights:
-                    master_weights[static_name] = paddle.zeros(shape, dtype="float32")
-                    prefix = f"{static_name}_fp32_master_0_"
-                else:
-                    prefix = f"{static_name}_"
-
-                for moment in moments:
-                    key = f"{prefix}{moment}"
+            for static_name, shape in param_iter:
+                master_weights[static_name] = paddle.zeros(shape, dtype="float32")
+                for moment in ("moment1_0", "moment2_0"):
+                    key = f"{static_name}_fp32_master_0_{moment}"
                     state_dict[key] = paddle.zeros(shape, dtype="float32")
-                for beta in betas:
-                    key = f"{prefix}{beta}"
+                for beta in ("beta1_pow_acc_0", "beta2_pow_acc_0"):
+                    key = f"{static_name}_fp32_master_0_{beta}"
                     state_dict[key] = paddle.zeros((1,), dtype="float32")
             return master_weights, state_dict
 
         def buffer_params():
             for buffer in optimizer._comm_buffer_list:
                 for param_name, grad_view in buffer._sharding_param_grad_view.items():
-                    struct_name = static_to_struct_mapping[param_name]
-                    if not any(
-                        struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names
-                    ):
-                        continue
+                    numel = grad_view._param.numel().item()
                     param_begin = grad_view._param_begin
                     param_end = grad_view._param_end
-                    shape = (param_end - param_begin,)
-                    no_need_master_weights = grad_view._param.dtype == paddle.float32
-
+                    index = grad_view._index
+                    padding_begin = index + numel
+                    shape = (min(padding_begin, param_end) - param_begin,)
                     if shape[0] > 0:
-                        yield param_name, shape, no_need_master_weights
+                        yield param_name, shape
 
         master_weights, state_dict = init_param_optimizer_states(buffer_params())
         state_dict["master_weights"] = master_weights
         state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}
         optimizer.set_state_dict(state_dict)
         return
-    param_list = []
-    for param in optimizer._parameter_list:
-        param_name = param.name
-        struct_name = static_to_struct_mapping[param_name]
-        if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
-            continue
-        param_list.append(param)
-    optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
+    optimizer._create_accumulators(
+        paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list
+    )
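
Note (not part of the patch): the key change in buffer_params() is that the local optimizer-state
shard is now sized as min(index + numel, param_end) - param_begin, i.e. the shard is clipped to the
un-padded region of the comm buffer instead of spanning the full [param_begin, param_end) slice.
A minimal standalone sketch of that computation follows; the helper name and the numbers are
hypothetical, only the formula comes from the diff above.

# Sketch of the padding-aware shard-shape computation (hypothetical helper, not PaddleNLP API).
def local_state_shape(index, numel, param_begin, param_end):
    # The buffer slice may extend past the parameter's real data with padding;
    # min(index + numel, param_end) clips the shard to the un-padded region.
    padding_begin = index + numel
    return (min(padding_begin, param_end) - param_begin,)


# Hypothetical example: a 10-element parameter placed at offset 0 of a buffer
# whose local slice covers [0, 12); the trailing 2 slots are padding.
assert local_state_shape(index=0, numel=10, param_begin=0, param_end=12) == (10,)
# A rank whose slice lies entirely inside the padding gets an empty shard,
# which buffer_params() then skips via its `if shape[0] > 0` check.
assert local_state_shape(index=0, numel=10, param_begin=10, param_end=12) == (0,)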