
Commit da9ddf7

Authored by umiswing and bo-ke
[feat] add params save_sharding_stage1_model_include_freeze_params (#9198) (#10551) (#11045)
Co-authored-by: Ferrebo <[email protected]>
1 parent: abae892

1 file changed (+9, −4 lines)


paddlenlp/trainer/utils/sharding_io.py

Lines changed: 9 additions & 4 deletions
@@ -61,7 +61,7 @@ def to_device(tensor, place=None):
     return tensor


-def filter_sharded_params(state_dict, optimizer, sharding_group):
+def filter_sharded_params(state_dict, optimizer, sharding_group, include_freeze_params=False):

     sharding_rank = sharding_group.rank
     sharding_world_size = sharding_group.nranks
@@ -80,7 +80,7 @@ def filter_sharded_params(state_dict, optimizer, sharding_group):
                 if sharded_rank != sharding_rank:
                     continue
                 filtered_state_dict[k] = v
-            else:
+            elif include_freeze_params:
                 if sharding_rank == 0:
                     filtered_state_dict[k] = v
     else:
@@ -91,7 +91,7 @@ def filter_sharded_params(state_dict, optimizer, sharding_group):
         for (k, v) in state_dict.items():
             if v.name in filtered_parameters:
                 filtered_state_dict[k] = v
-            elif v.name not in [p.name for p in parameters]:
+            elif include_freeze_params and (v.name not in [p.name for p in parameters]):
                 if sharding_rank == 0:
                     filtered_state_dict[k] = v
     return filtered_state_dict
@@ -375,7 +375,12 @@ def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=
         if state_dict is None:
             state_dict = model_to_save.state_dict()
         if self.args.should_save_sharding_stage1_model:
-            state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group)
+            state_dict = filter_sharded_params(
+                state_dict,
+                self.optimizer,
+                self.sharding_group,
+                self.args.save_sharding_stage1_model_include_freeze_params,
+            )

         config_to_save = None
         merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
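In effect, the change gates where frozen (non-trainable) parameters land in a sharding stage1 checkpoint: trainable parameters are still written only by the sharding rank that owns their shard, while parameters the sharded optimizer does not track are written by rank 0 only when the new switch is on, and are skipped otherwise. Below is a minimal, self-contained sketch of that rule, not the PaddleNLP implementation; the toy param2rank mapping stands in for the optimizer's owner bookkeeping (e.g. _param2rank in the V1 path), and all names here are illustrative.

    from collections import OrderedDict

    def filter_sharded_params_sketch(state_dict, param2rank, sharding_rank, include_freeze_params=False):
        # Keep each trainable parameter only on the sharding rank that owns
        # its shard; optionally keep frozen parameters (those absent from
        # param2rank) on rank 0 alone.
        filtered = OrderedDict()
        for name, tensor in state_dict.items():
            if name in param2rank:
                if param2rank[name] == sharding_rank:
                    filtered[name] = tensor
            elif include_freeze_params and sharding_rank == 0:
                filtered[name] = tensor
        return filtered

    # Two trainable weights sharded across ranks 0 and 1, plus one frozen weight.
    state = OrderedDict(w0="tensor0", w1="tensor1", frozen_emb="tensor2")
    owners = {"w0": 0, "w1": 1}
    assert list(filter_sharded_params_sketch(state, owners, sharding_rank=0)) == ["w0"]
    assert list(filter_sharded_params_sketch(state, owners, sharding_rank=0, include_freeze_params=True)) == ["w0", "frozen_emb"]
    assert list(filter_sharded_params_sketch(state, owners, sharding_rank=1, include_freeze_params=True)) == ["w1"]

With include_freeze_params left at its default of False, frozen weights are omitted from every rank's shard; setting it to True restores the pre-change behavior of writing them once, on rank 0.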

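For completeness, a hedged sketch of how the new option might be switched on through PaddleNLP's TrainingArguments. The field name is taken from the self.args access in the diff above; output_dir and sharding are placeholders, and additional arguments (whatever enables sharded-model saving in a given setup) may still be required.

    from paddlenlp.trainer import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints",  # placeholder path
        sharding="stage1",           # the code path this commit touches
        # The boolean field this commit threads through to filter_sharded_params:
        save_sharding_stage1_model_include_freeze_params=True,
    )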