@@ -1505,37 +1505,13 @@ def __init__(self, args, zcc_manager, timer, unused_arg):
         self.sharding_group = self.hcg.get_sharding_parallel_group()
 
     def _manipulate_state_dict_and_config(self, model_to_save, optimizer):
-        # return model_to_save.sharded_state_dict()
-
         group_getter = GroupGetter(model_to_save)
         gids = group_getter.get_group_ids()
-        from paddlenlp.trainer.utils.sharding_io import (
-            exclude_parameters_in_state_dict,
-            filter_sharded_params,
-        )
-
-        # filter_sharded_params = sharded_state_dict_compatibility(filter_sharded_params, return_sharded_state_dict=True)
-        # exclude_parameters_in_state_dict = sharded_state_dict_compatibility(
-        #     exclude_parameters_in_state_dict, return_sharded_state_dict=True
-        # )
+        from paddlenlp.trainer.utils.sharding_io import exclude_parameters_in_state_dict
 
         state_dict = model_to_save.state_dict()
-        # tmp wa should_save_sharding_stage1_model
-        if self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint":
-            state_dict = split_model_state(state_dict, group_getter)
-            for gid in gids:
-                state_dict[gid] = filter_sharded_params(
-                    state_dict.get(gid, {}),
-                    optimizer,
-                    self.sharding_group,
-                    self.args.save_sharding_stage1_model_include_freeze_params,
-                )
-            state_dict = merge_model_state(state_dict)
 
-        # tmp wa should_save_sharding_stage1_model
-        if self.args.bf16 and (
-            self.args.should_save_sharding_stage1_model or self.args.save_checkpoint_format == "flex_checkpoint"
-        ):
+        if self.args.bf16:
             param_names_in_master_weights = []
             optimzier_state_dict = optimizer.state_dict()
             optimzier_state_dict = split_opt_state(optimzier_state_dict, group_getter)
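The branch removed above appears to implement a split → filter → merge pass over the model state dict before saving: `split_model_state` buckets parameters by group id, `filter_sharded_params` keeps only the parameters this sharding rank should persist, and `merge_model_state` flattens the groups back into one dict. Below is a minimal, self-contained toy sketch of that pattern; the helper names and signatures are illustrative stand-ins only, not PaddleNLP's actual APIs (the real `filter_sharded_params` also takes the optimizer, the sharding group, and a freeze-params flag).

```python
# Toy stand-ins for split_model_state / filter_sharded_params / merge_model_state.
# Illustrative only; not the PaddleNLP implementations.

def split_by_group(state_dict, group_of):
    """Bucket state-dict entries into per-group sub-dicts keyed by group id."""
    grouped = {}
    for name, value in state_dict.items():
        grouped.setdefault(group_of(name), {})[name] = value
    return grouped

def filter_group(group_state, keep):
    """Keep only the entries this rank is responsible for saving."""
    return {name: v for name, v in group_state.items() if keep(name)}

def merge_groups(grouped):
    """Flatten the per-group sub-dicts back into one state dict."""
    merged = {}
    for group_state in grouped.values():
        merged.update(group_state)
    return merged

state = {"layer0.weight": 0, "layer0.bias": 1, "layer1.weight": 2}
grouped = split_by_group(state, group_of=lambda n: n.split(".")[0])
filtered = {gid: filter_group(g, keep=lambda n: n.endswith("weight"))
            for gid, g in grouped.items()}
print(merge_groups(filtered))  # {'layer0.weight': 0, 'layer1.weight': 2}
```

After this hunk, that pass is gone from `_manipulate_state_dict_and_config`, and the bf16 master-weights bookkeeping that follows runs whenever `args.bf16` is set rather than only for sharding-stage1 or `flex_checkpoint` saves.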