paddlenlp/trainer/trainer.py (11 changes: 2 additions & 9 deletions)
@@ -815,13 +815,7 @@ def get_metadata_file_name(path):

         self._load_scheduler(resume_from_checkpoint)

-        from .trainer_utils import ShardingOption
-
-        should_load_stage1 = self.args.sharding_parallel_degree > 1 and ShardingOption.SHARD_OP in self.args.sharding
-        logger.debug(f"should_load_stage1 = {should_load_stage1}")
-        logger.debug(f"sharded_model_from_ema = {self.args.sharded_model_from_ema}")
-
-        if should_load_stage1 and self.args.sharded_model_from_ema:
+        if self.args.sharded_model_from_ema:
             ema_states_path = os.path.join(resume_from_checkpoint, EMA_STATE_DIC, f"{dist.get_rank()}_0.distcp")
             ema_state_dict = paddle.load(ema_states_path)
             ema_master_weights = ema_state_dict.pop("master_weights", None)
@@ -832,7 +826,6 @@ def get_metadata_file_name(path):
), f"{k} not in ema_master_weights, emas_master_weight keys {ema_master_weights.keys()}"
paddle.assign(ema_master_weights[k], opt_master_weights[k])

ema_state_dict = reshard_util.all_gather_state_dict(ema_state_dict, lambda x: True, self.sharding_group)
self.model.set_state_dict(ema_state_dict)
else:

@@ -854,7 +847,7 @@ def bf16_filtered_sharded_state_dict(sharded_state_dict):
                 comm_method=self.args.flex_ckpt_comm_method,
             )

-        if self.args.bf16 and (not self.args.ignore_load_lr_and_optim) and should_load_stage1:
+        if self.args.bf16 and (not self.args.ignore_load_lr_and_optim):
             opt_state_dict = self.optimizer.state_dict()

             def recover_params_from_master_weight(opt_state_dict, group):
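Note on the EMA branch above: the surrounding code copies EMA master weights into the optimizer's master weights entry by entry with paddle.assign. The following is a minimal, self-contained sketch of that copy pattern with toy tensors; it is an illustration only, not code from the PR, and the dict names are placeholders.

# Sketch of the key-by-key EMA -> optimizer master-weight copy pattern.
# Toy tensors stand in for real checkpoint and optimizer state contents.
import paddle

opt_master_weights = {"linear_0.w_0": paddle.zeros([2, 2], dtype="float32")}
ema_master_weights = {"linear_0.w_0": paddle.ones([2, 2], dtype="float32")}

for k in opt_master_weights:
    assert k in ema_master_weights, f"{k} not in ema_master_weights"
    # paddle.assign copies values in place, so the optimizer keeps its own
    # tensor objects and only their contents are overwritten.
    paddle.assign(ema_master_weights[k], opt_master_weights[k])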
paddlenlp/trainer/utils/zero_cost_checkpoint.py (28 changes: 2 additions & 26 deletions)
@@ -1505,37 +1505,13 @@ def __init__(self, args, zcc_manager, timer, unused_arg):
         self.sharding_group = self.hcg.get_sharding_parallel_group()

     def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
-        # return model_to_save.sharded_state_dict()
-
         group_getter = GroupGetter(model_to_save)
         gids = group_getter.get_group_ids()
-        from paddlenlp.trainer.utils.sharding_io import (
-            exclude_parameters_in_state_dict,
-            filter_sharded_params,
-        )
-
-        # filter_sharded_params = sharded_state_dict_compatibility(filter_sharded_params, return_sharded_state_dict=True)
-        # exclude_parameters_in_state_dict = sharded_state_dict_compatibility(
-        #     exclude_parameters_in_state_dict, return_sharded_state_dict=True
-        # )
+        from paddlenlp.trainer.utils.sharding_io import exclude_parameters_in_state_dict

         state_dict = model_to_save.state_dict()
-        # tmp wa should_save_sharding_stage1_model
-        if self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint":
-            state_dict = split_model_state(state_dict, group_getter)
-            for gid in gids:
-                state_dict[gid] = filter_sharded_params(
-                    state_dict.get(gid, {}),
-                    optimizer,
-                    self.sharding_group,
-                    self.args.save_sharding_stage1_model_include_freeze_params,
-                )
-            state_dict = merge_model_state(state_dict)

-        # tmp wa should_save_sharding_stage1_model
-        if self.args.bf16 and (
-            self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"
-        ):
+        if self.args.bf16:
             param_names_in_master_weights = []
             optimzier_state_dict = optimizer.state_dict()
             optimzier_state_dict = split_opt_state(optimzier_state_dict, group_getter)
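For orientation, the hunk above keeps the per-group handling of the optimizer state (split_opt_state driven by a GroupGetter) while dropping the sharding-stage1 filtering of the model state. The snippet below is a generic sketch of the split-by-group and merge-back idea only; split_by_group and merge_groups are hypothetical stand-ins, not the PaddleNLP helpers.

# Generic sketch: partition a flat state dict into per-group sub-dicts and
# merge them back. Illustration of the idea behind split_opt_state /
# split_model_state / merge_model_state; not their actual implementation.
from collections import defaultdict

def split_by_group(state_dict, get_group_id):
    grouped = defaultdict(dict)
    for name, value in state_dict.items():
        grouped[get_group_id(name)][name] = value
    return dict(grouped)

def merge_groups(grouped):
    merged = {}
    for sub_dict in grouped.values():
        merged.update(sub_dict)
    return merged

state = {"layer_0.w": 0.1, "layer_0.b": 0.2, "layer_1.w": 0.3}
grouped = split_by_group(state, lambda name: name.split(".")[0])
assert merge_groups(grouped) == state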