@@ -61,7 +61,7 @@ def to_device(tensor, place=None):
     return tensor
 
 
-def filter_sharded_params(state_dict, optimizer, sharding_group):
+def filter_sharded_params(state_dict, optimizer, sharding_group, include_freeze_params=False):
 
     sharding_rank = sharding_group.rank
     sharding_world_size = sharding_group.nranks
@@ -80,7 +80,7 @@ def filter_sharded_params(state_dict, optimizer, sharding_group):
                 if sharded_rank != sharding_rank:
                     continue
                 filtered_state_dict[k] = v
-            else:
+            elif include_freeze_params:
                 if sharding_rank == 0:
                     filtered_state_dict[k] = v
     else:
@@ -91,7 +91,7 @@ def filter_sharded_params(state_dict, optimizer, sharding_group):
         for (k, v) in state_dict.items():
             if v.name in filtered_parameters:
                 filtered_state_dict[k] = v
-            elif v.name not in [p.name for p in parameters]:
+            elif include_freeze_params and (v.name not in [p.name for p in parameters]):
                 if sharding_rank == 0:
                     filtered_state_dict[k] = v
     return filtered_state_dict
@@ -375,7 +375,12 @@ def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=
         if state_dict is None:
             state_dict = model_to_save.state_dict()
         if self.args.should_save_sharding_stage1_model:
-            state_dict = filter_sharded_params(state_dict, self.optimizer, self.sharding_group)
+            state_dict = filter_sharded_params(
+                state_dict,
+                self.optimizer,
+                self.sharding_group,
+                self.args.save_sharding_stage1_model_include_freeze_params,
+            )
 
         config_to_save = None
         merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
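
The patch above changes which entries survive filtering: parameters owned by the sharded optimizer are still kept only by their owning sharding rank, while parameters the optimizer does not track (frozen parameters) are now dropped by default and saved by sharding rank 0 only when include_freeze_params is True. Below is a minimal, framework-free sketch of that rule for illustration; filter_sharded_params_sketch, param_to_rank, and the plain-dict state_dict are hypothetical stand-ins and not part of the real Paddle code path, which operates on tensors and a sharded optimizer's param-to-rank mapping.

    # Sketch only: mimics the branch structure of the patched filter_sharded_params.
    def filter_sharded_params_sketch(state_dict, param_to_rank, sharding_rank,
                                     include_freeze_params=False):
        filtered = {}
        for key, name in state_dict.items():
            if name in param_to_rank:
                # Trainable params: kept only by the sharding rank that owns them.
                if param_to_rank[name] == sharding_rank:
                    filtered[key] = name
            elif include_freeze_params and sharding_rank == 0:
                # Frozen params are unknown to the sharded optimizer; with the
                # new flag they are saved exactly once, by sharding rank 0.
                filtered[key] = name
        return filtered

    if __name__ == "__main__":
        state = {"layer.weight": "w0", "layer.bias": "b0", "frozen.emb": "e0"}
        ranks = {"w0": 0, "b0": 1}  # frozen.emb is not tracked by the optimizer
        print(filter_sharded_params_sketch(state, ranks, sharding_rank=0))
        # {'layer.weight': 'w0'}
        print(filter_sharded_params_sketch(state, ranks, 0, include_freeze_params=True))
        # {'layer.weight': 'w0', 'frozen.emb': 'e0'}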