@@ -115,44 +115,47 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serialization=False):
     else:
         raise ValueError("Unified checkpoint only supports PretrainedModel")
 
+    skip_save_model_weight = False
     if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
         if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
             logger.info(
                 f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save."
                 "The master weight will be loaded as model weights for next resumption."
             )
             # not save model weight, load from master weight
-            return
-    config_to_save = None
-    state_dict, config_to_save, shard_file, sharded_index = unified_checkpoint_into_shards(
-        args, model_to_save, safe_serialization=safe_serialization
-    )
+            skip_save_model_weight = True
 
     save_directory = output_dir
     os.makedirs(save_directory, exist_ok=True)
 
-    is_sync_save = True
-    if "async_save" in args.unified_checkpoint_config:
-        is_sync_save = False
-    file_save_async_or_sync(
-        state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
-    )
+    # save model weights
+    if not skip_save_model_weight:
+        state_dict, shard_file, sharded_index = unified_checkpoint_into_shards(
+            args, model_to_save, safe_serialization=safe_serialization
+        )
+        is_sync_save = True
+        if "async_save" in args.unified_checkpoint_config:
+            is_sync_save = False
+        file_save_async_or_sync(
+            state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
+        )
+
+        if sharded_index is not None:
+            if not safe_serialization:
+                path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
+            else:
+                path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
+
+            with open(path, "w") as f:
+                json.dump(sharded_index, f, indent=4)
 
+    # save the config
+    config_to_save = save_config(model_to_save)
     # Attach architecture to the config
     config_to_save.architectures = [model_to_save.__class__.__name__]
-    # Save the config
     if args.should_save:
         config_to_save.save_pretrained(save_directory)
 
-    if sharded_index is not None:
-        if not safe_serialization:
-            path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
-        else:
-            path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
-
-        with open(path, "w") as f:
-            json.dump(sharded_index, f, indent=4)
-
 
 def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, safe_serialization=False) -> None:
     """Load potential model checkpoint
@@ -252,6 +255,18 @@ def _remove_unused_keys(
         raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
 
 
+def save_config(model_to_save):
+    dtype = get_parameter_dtype(model_to_save)
+    model_to_save.config.dtype = str(dtype).split(".")[1]
+    config_to_save = copy.deepcopy(model_to_save.config)
+
+    if config_to_save.tensor_parallel_degree > 1:
+        # do we need to change?
+        config_to_save.tensor_parallel_degree = 1
+
+    return config_to_save
+
+
 def unified_checkpoint_into_shards(
     args,
     model_to_save,
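Note: save_config stamps the parameter dtype onto the config (e.g. "float16" rather than "paddle.float16") and deep-copies it with tensor_parallel_degree reset to 1, so the saved config describes an unsharded model. A minimal runnable sketch of that behaviour on a toy config; the ToyConfig class and the dtype string are illustrative assumptions, not PaddleNLP types:

import copy
from dataclasses import dataclass

# Toy stand-in for the real model config; fields mirror the ones touched by save_config.
@dataclass
class ToyConfig:
    dtype: str = "float32"
    tensor_parallel_degree: int = 4

def sketch_save_config(config, parameter_dtype="paddle.float16"):
    config.dtype = str(parameter_dtype).split(".")[1]    # keep only "float16"
    config_to_save = copy.deepcopy(config)
    if config_to_save.tensor_parallel_degree > 1:
        config_to_save.tensor_parallel_degree = 1        # saved config is single-card
    return config_to_save

print(sketch_save_config(ToyConfig()))  # ToyConfig(dtype='float16', tensor_parallel_degree=1)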
@@ -272,8 +287,6 @@ def unified_checkpoint_into_shards(
 
     all_filter_keys = filter_params(model_to_save, state_dict)
 
-    dtype = get_parameter_dtype(model_to_save)
-    model_to_save.config.dtype = str(dtype).split(".")[1]
     config_to_save = copy.deepcopy(model_to_save.config)
 
     if config_to_save.tensor_parallel_degree > 1:
@@ -282,10 +295,6 @@ def unified_checkpoint_into_shards(
         )
         state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)
 
-    if config_to_save.tensor_parallel_degree > 1:
-        # do we need to change?
-        config_to_save.tensor_parallel_degree = 1
-
     # build index json file
     index_weight_file = {}
     total_size = 0
@@ -302,7 +311,7 @@ def unified_checkpoint_into_shards(
         total_size_list,
     )
 
-    return state_dict, config_to_save, shard_file, sharded_index
+    return state_dict, shard_file, sharded_index
 
 
 def save_unified_optimizer(args, model, optimizer, output_dir, safe_serialization=False):
@@ -343,16 +352,17 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serialization=False):
     )
 
     if sharded_optim_index is not None:
-        if not safe_serialization:
-            path = os.path.join(output_dir, PADDLE_OPTIMIZER_INDEX_NAME)
-            master_path = os.path.join(output_dir, PADDLE_MASTER_WEIGHTS_INDEX_NAME)
-        else:
-            path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME)
-            master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME)
-
+        optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME if safe_serialization else PADDLE_OPTIMIZER_INDEX_NAME
+        path = os.path.join(output_dir, optimizer_index_name)
         with open(path, "w") as f:
             json.dump(sharded_optim_index, f, indent=4)
 
+        master_weights_name = (
+            SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
+        )
+        if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+            master_weights_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
+        master_path = os.path.join(output_dir, master_weights_name)
         if master_weight_state_dict is not None:
             with open(master_path, "w") as f:
                 json.dump(sharded_master_weight_index, f, indent=4)
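Note: when skip_save_model_weight is configured, the master-weights index is written under the model-weights index filename, which is what lets the next run pick up the master weights as model weights. A standalone sketch of that naming rule; the literal filenames are illustrative guesses, not necessarily PaddleNLP's actual constants:

# Sketch of the master-weights index naming rule above; filenames are illustrative only.
def master_weights_index_name(safe_serialization, skip_save_model_weight):
    if skip_save_model_weight:
        # Master weights stand in for model weights, so reuse the model-weight index name.
        return "model.safetensors.index.json" if safe_serialization else "model_state.pdparams.index.json"
    return "master_weights.safetensors.index.json" if safe_serialization else "master_weights.pdparams.index.json"

print(master_weights_index_name(safe_serialization=True, skip_save_model_weight=True))
# -> model.safetensors.index.json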
@@ -561,6 +571,8 @@ def unified_optimizer_into_shards(
     total_optim_size, total_master_weight_size = 0, 0
     optimizer_name = SAFE_OPTIMIZER_NAME if safe_serialization else PADDLE_OPTIMIZER_NAME
     master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME
+    if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+        master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME
     shard_optimizer_file = get_sharded_file_name(args, optimizer_name, is_optimizer=True)
     shard_master_weight_file = get_sharded_file_name(args, master_weights_name, is_optimizer=True)
 
@@ -1648,6 +1660,10 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_serialization):
             index_filename_master_weights = (
                 PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
             )
+            if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+                index_filename_master_weights = (
+                    PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
+                )
     else:
         has_master_weight = False
         index_filename_master_weights = None
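Note: on resume, update_master_weight_status now points the master-weights loader at the model-weight index when the option is set, mirroring the save side. The index files written and read above follow the usual sharded-checkpoint layout; a hedged illustration of that shape, where the keys follow the common Hugging Face-style convention and the map entries and sizes are invented for the example, not produced by this patch:

import json

# Illustrative index-file shape; weight_map entries and total_size are made up.
sharded_index = {
    "metadata": {"total_size": 2200000000},
    "weight_map": {
        "embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    },
}
print(json.dumps(sharded_index, indent=4))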