@@ -156,6 +156,8 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
156
156
if args .should_save :
157
157
config_to_save .save_pretrained (save_directory )
158
158
159
+ paddle .device .cuda .empty_cache ()
160
+
159
161
160
162
def load_unified_checkpoint (args , model , optimizer , resume_from_checkpoint : str , safe_serialization = False ) -> None :
161
163
"""Load potential model checkpoint
@@ -281,6 +283,7 @@ def unified_checkpoint_into_shards(
281
283
Returns:
282
284
tuple: state_dict, shard_file: file name, sharded_index: map for weight to file name.
283
285
"""
286
+ paddle .device .cuda .empty_cache ()
284
287
assert hasattr (model_to_save , "config" )
285
288
286
289
state_dict = model_to_save .state_dict ()
@@ -311,6 +314,8 @@ def unified_checkpoint_into_shards(
311
314
total_size_list ,
312
315
)
313
316
317
+ paddle .device .cuda .empty_cache ()
318
+
314
319
return state_dict , shard_file , sharded_index
315
320
316
321
@@ -333,6 +338,8 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serializatio
333
338
optim_state_dict , shard_optim_file , sharded_optim_index = results [0 ]
334
339
master_weight_state_dict , shard_master_weight_file , sharded_master_weight_index = results [1 ]
335
340
341
+ paddle .device .cuda .empty_cache ()
342
+
336
343
save_directory = output_dir
337
344
os .makedirs (save_directory , exist_ok = True )
338
345
@@ -514,6 +521,7 @@ def unified_optimizer_into_shards(
514
521
optimizer (Optimizer): optimizer to save.
515
522
safe_serialization (bool, optional): safe serialization using safetensors. Defaults to False.
516
523
"""
524
+ paddle .device .cuda .empty_cache ()
517
525
optim_state_dict = nested_copy (optimizer .state_dict ())
518
526
master_weights = None
519
527
if "master_weights" in optim_state_dict .keys ():
@@ -559,12 +567,15 @@ def unified_optimizer_into_shards(
559
567
tp_actions ,
560
568
filter_optim_keys ,
561
569
)
570
+ paddle .device .cuda .empty_cache ()
571
+
562
572
if master_weights is not None :
563
573
master_weights = merge_tensor_parallel_for_optimizer (
564
574
master_weights ,
565
575
tp_actions ,
566
576
filter_master_keys ,
567
577
)
578
+ paddle .device .cuda .empty_cache ()
568
579
569
580
# build index json file
570
581
index_optimizer_file , index_master_weight_file = {}, {}
@@ -601,6 +612,7 @@ def unified_optimizer_into_shards(
601
612
else :
602
613
sharded_optim_index ["master_weights" ] = False
603
614
615
+ paddle .device .cuda .empty_cache ()
604
616
if master_weights is None :
605
617
return [(optim_state_dict , shard_optimizer_file , sharded_optim_index )]
606
618
else :
0 commit comments