Commit 51f73a5

disable only sharding opt in save stage (#11200)
1 parent 2ac8599 commit 51f73a5

2 files changed: 4 additions, 35 deletions

paddlenlp/trainer/trainer.py

Lines changed: 2 additions & 9 deletions
@@ -815,13 +815,7 @@ def get_metadata_file_name(path):
 
         self._load_scheduler(resume_from_checkpoint)
 
-        from .trainer_utils import ShardingOption
-
-        should_load_stage1 = self.args.sharding_parallel_degree > 1 and ShardingOption.SHARD_OP in self.args.sharding
-        logger.debug(f"should_load_stage1 = {should_load_stage1}")
-        logger.debug(f"sharded_model_from_ema = {self.args.sharded_model_from_ema}")
-
-        if should_load_stage1 and self.args.sharded_model_from_ema:
+        if self.args.sharded_model_from_ema:
             ema_states_path = os.path.join(resume_from_checkpoint, EMA_STATE_DIC, f"{dist.get_rank()}_0.distcp")
             ema_state_dict = paddle.load(ema_states_path)
             ema_master_weights = ema_state_dict.pop("master_weights", None)
@@ -832,7 +826,6 @@ def get_metadata_file_name(path):
                 ), f"{k} not in ema_master_weights, emas_master_weight keys {ema_master_weights.keys()}"
                 paddle.assign(ema_master_weights[k], opt_master_weights[k])
 
-            ema_state_dict = reshard_util.all_gather_state_dict(ema_state_dict, lambda x: True, self.sharding_group)
             self.model.set_state_dict(ema_state_dict)
         else:
 
@@ -854,7 +847,7 @@ def bf16_filtered_sharded_state_dict(sharded_state_dict):
                 comm_method=self.args.flex_ckpt_comm_method,
             )
 
-        if self.args.bf16 and (not self.args.ignore_load_lr_and_optim) and should_load_stage1:
+        if self.args.bf16 and (not self.args.ignore_load_lr_and_optim):
             opt_state_dict = self.optimizer.state_dict()
 
             def recover_params_from_master_weight(opt_state_dict, group):
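Taken together, the trainer.py hunks drop the sharding-stage-1 gate (`should_load_stage1`) from the resume path: restoring model weights from the EMA checkpoint now depends only on `sharded_model_from_ema`, and the EMA state dict is no longer all-gathered across the sharding group before `set_state_dict`. Below is a minimal, hypothetical sketch of the gate change only; the `Args` container and function names are invented for illustration, and the string "stage1" stands in for `ShardingOption.SHARD_OP`.

```python
# Hypothetical sketch (not the PaddleNLP implementation): how the gate on the
# EMA-restore path changes with this commit.
from dataclasses import dataclass


@dataclass
class Args:
    sharding_parallel_degree: int
    sharding: set  # e.g. {"stage1"} when sharding stage 1 (SHARD_OP) is enabled
    sharded_model_from_ema: bool


def should_restore_from_ema_before(args: Args) -> bool:
    # Old behavior: restoring weights from the EMA checkpoint also required
    # sharding stage 1 (optimizer-state sharding) to be enabled.
    should_load_stage1 = args.sharding_parallel_degree > 1 and "stage1" in args.sharding
    return should_load_stage1 and args.sharded_model_from_ema


def should_restore_from_ema_after(args: Args) -> bool:
    # New behavior: the EMA restore path depends only on the flag itself.
    return args.sharded_model_from_ema


if __name__ == "__main__":
    # With sharding disabled, the old gate skipped the EMA restore;
    # the new gate honors `sharded_model_from_ema` regardless.
    args = Args(sharding_parallel_degree=1, sharding=set(), sharded_model_from_ema=True)
    print(should_restore_from_ema_before(args))  # False
    print(should_restore_from_ema_after(args))   # True
```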

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 2 additions & 26 deletions
@@ -1505,37 +1505,13 @@ def __init__(self, args, zcc_manager, timer, unused_arg):
         self.sharding_group = self.hcg.get_sharding_parallel_group()
 
     def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
-        # return model_to_save.sharded_state_dict()
-
         group_getter = GroupGetter(model_to_save)
         gids = group_getter.get_group_ids()
-        from paddlenlp.trainer.utils.sharding_io import (
-            exclude_parameters_in_state_dict,
-            filter_sharded_params,
-        )
-
-        # filter_sharded_params = sharded_state_dict_compatibility(filter_sharded_params, return_sharded_state_dict=True)
-        # exclude_parameters_in_state_dict = sharded_state_dict_compatibility(
-        #     exclude_parameters_in_state_dict, return_sharded_state_dict=True
-        # )
+        from paddlenlp.trainer.utils.sharding_io import exclude_parameters_in_state_dict
 
         state_dict = model_to_save.state_dict()
-        # tmp wa should_save_sharding_stage1_model
-        if self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint":
-            state_dict = split_model_state(state_dict, group_getter)
-            for gid in gids:
-                state_dict[gid] = filter_sharded_params(
-                    state_dict.get(gid, {}),
-                    optimizer,
-                    self.sharding_group,
-                    self.args.save_sharding_stage1_model_include_freeze_params,
-                )
-            state_dict = merge_model_state(state_dict)
 
-        # tmp wa should_save_sharding_stage1_model
-        if self.args.bf16 and (
-            self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"
-        ):
+        if self.args.bf16:
             param_names_in_master_weights = []
             optimzier_state_dict = optimizer.state_dict()
             optimzier_state_dict = split_opt_state(optimzier_state_dict, group_getter)
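On the save side, the zero_cost_checkpoint.py hunk removes the per-group `filter_sharded_params` pass over the model state dict and no longer ties the bf16 master-weight bookkeeping to the sharding-stage-1 save path (or the flex_checkpoint format). Below is a minimal, hypothetical sketch of just the condition change, assuming argument flags named as in the diff; the `Args` container and function names are invented for illustration.

```python
# Hypothetical sketch (not the PaddleNLP code): old vs. new gating of the
# bf16 master-weight handling in the save stage.
from dataclasses import dataclass


@dataclass
class Args:
    bf16: bool
    should_save_sharding_stage1_model: bool
    save_checkpoint_format: str


def runs_bf16_master_weight_pass_before(args: Args) -> bool:
    # Old behavior: the bf16 master-weight bookkeeping ran only when the
    # sharding-stage-1 save path or the flex_checkpoint format was active.
    return args.bf16 and (
        args.should_save_sharding_stage1_model or args.save_checkpoint_format == "flex_checkpoint"
    )


def runs_bf16_master_weight_pass_after(args: Args) -> bool:
    # New behavior: it runs for every bf16 run, independent of the sharding setup.
    return args.bf16


if __name__ == "__main__":
    args = Args(bf16=True, should_save_sharding_stage1_model=False, save_checkpoint_format="unified")
    print(runs_bf16_master_weight_pass_before(args))  # False
    print(runs_bf16_master_weight_pass_after(args))   # True
```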
