diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index bacfa3bdf46f..82e7b7040ec0 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -815,13 +815,7 @@ def get_metadata_file_name(path):
 
         self._load_scheduler(resume_from_checkpoint)
 
-        from .trainer_utils import ShardingOption
-
-        should_load_stage1 = self.args.sharding_parallel_degree > 1 and ShardingOption.SHARD_OP in self.args.sharding
-        logger.debug(f"should_load_stage1 = {should_load_stage1}")
-        logger.debug(f"sharded_model_from_ema = {self.args.sharded_model_from_ema}")
-
-        if should_load_stage1 and self.args.sharded_model_from_ema:
+        if self.args.sharded_model_from_ema:
             ema_states_path = os.path.join(resume_from_checkpoint, EMA_STATE_DIC, f"{dist.get_rank()}_0.distcp")
             ema_state_dict = paddle.load(ema_states_path)
             ema_master_weights = ema_state_dict.pop("master_weights", None)
@@ -832,7 +826,6 @@ def get_metadata_file_name(path):
                ), f"{k} not in ema_master_weights, emas_master_weight keys {ema_master_weights.keys()}"
                paddle.assign(ema_master_weights[k], opt_master_weights[k])
 
-            ema_state_dict = reshard_util.all_gather_state_dict(ema_state_dict, lambda x: True, self.sharding_group)
             self.model.set_state_dict(ema_state_dict)
         else:
@@ -854,7 +847,7 @@ def bf16_filtered_sharded_state_dict(sharded_state_dict):
                comm_method=self.args.flex_ckpt_comm_method,
            )
 
-            if self.args.bf16 and (not self.args.ignore_load_lr_and_optim) and should_load_stage1:
+            if self.args.bf16 and (not self.args.ignore_load_lr_and_optim):
                opt_state_dict = self.optimizer.state_dict()
 
                def recover_params_from_master_weight(opt_state_dict, group):
diff --git a/paddlenlp/trainer/utils/zero_cost_checkpoint.py b/paddlenlp/trainer/utils/zero_cost_checkpoint.py
index 0b4b4a57990a..bb2c316a7164 100644
--- a/paddlenlp/trainer/utils/zero_cost_checkpoint.py
+++ b/paddlenlp/trainer/utils/zero_cost_checkpoint.py
@@ -1505,37 +1505,13 @@ def __init__(self, args, zcc_manager, timer, unused_arg):
         self.sharding_group = self.hcg.get_sharding_parallel_group()
 
     def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
-        # return model_to_save.sharded_state_dict()
-
         group_getter = GroupGetter(model_to_save)
         gids = group_getter.get_group_ids()
 
-        from paddlenlp.trainer.utils.sharding_io import (
-            exclude_parameters_in_state_dict,
-            filter_sharded_params,
-        )
-
-        # filter_sharded_params = sharded_state_dict_compatibility(filter_sharded_params, return_sharded_state_dict=True)
-        # exclude_parameters_in_state_dict = sharded_state_dict_compatibility(
-        #     exclude_parameters_in_state_dict, return_sharded_state_dict=True
-        # )
+        from paddlenlp.trainer.utils.sharding_io import exclude_parameters_in_state_dict
 
         state_dict = model_to_save.state_dict()
-        # tmp wa should_save_sharding_stage1_model
-        if self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint":
-            state_dict = split_model_state(state_dict, group_getter)
-            for gid in gids:
-                state_dict[gid] = filter_sharded_params(
-                    state_dict.get(gid, {}),
-                    optimizer,
-                    self.sharding_group,
-                    self.args.save_sharding_stage1_model_include_freeze_params,
-                )
-            state_dict = merge_model_state(state_dict)
-        # tmp wa should_save_sharding_stage1_model
-        if self.args.bf16 and (
-            self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"
-        ):
+        if self.args.bf16:
             param_names_in_master_weights = []
             optimzier_state_dict = optimizer.state_dict()
             optimzier_state_dict = split_opt_state(optimzier_state_dict, group_getter)