from paddlenlp.transformers.model_utils import (
    PretrainedModel,
    _load_state_dict_into_model,
+    faster_set_state_dict,
    get_parameter_dtype,
    load_state_dict,
    unwrap_model,
from paddlenlp.utils.nested import nested_copy, nested_copy_place

if is_safetensors_available():
-    from safetensors import safe_open
+    # from safetensors import safe_open
    from safetensors.numpy import save_file as safe_save_file

+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open

FP32_MASTER = "fp32_master_0"
optimizer_scalar_name = [
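Note on the import swap above: fast_safe_open is used as a drop-in replacement for safetensors.safe_open, so it presumably mirrors that context-manager interface. A minimal usage sketch under that assumption (file name and framework argument are illustrative only):

from paddlenlp.utils.safetensors import fast_safe_open as safe_open

with safe_open("model-00001-of-00008.safetensors", framework="np") as f:
    for key in f.keys():
        weight = f.get_tensor(key)  # reads tensors lazily, one key at a time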
async_save_queue = []


+DEST_PLACE = paddle.CPUPlace()
+if paddle.device.is_compiled_with_cuda():
+    DEST_PLACE = paddle.CUDAPinnedPlace()
+
+
class UnifiedCheckpointOption(ExplicitEnum):
    """
    "- skip_save_model_weight: do not save model weights when the masters weight exist\n"
@@ -196,7 +203,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
    Returns:
        None
    """
-
    if paddle.distributed.get_world_size() <= 1:
        load_single_card_checkpoint(args, model, resume_from_checkpoint)
        return
@@ -222,7 +228,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
        pretrained_model_name_or_path=resume_from_checkpoint,
        index_filename=os.path.join(resume_from_checkpoint, index_filename),
    )
-
    loaded_keys = sharded_metadata["all_checkpoint_keys"]

    model_state_dict = get_expected_state_dict(model)
@@ -266,7 +271,9 @@ def _remove_unused_keys(
        else:
            tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
        # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-        state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
+        state_dict = load_state_dict(
+            shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
+        )

        if not pre_tensor_parallel_split:
            # Since we load all keys but we only need one of pipeline stages
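Note: device="expected" is a new argument to PaddleNLP's load_state_dict. The helper below is a hypothetical sketch of the idea only (place each deserialized weight directly on the device the model expects rather than staging it on CPU); it is not the actual implementation:

import numpy as np
import paddle

def to_expected_device(np_weight: np.ndarray, device: str = "expected") -> paddle.Tensor:
    if device == "expected":
        # paddle.to_tensor defaults to the current expected place (the active GPU under CUDA)
        return paddle.to_tensor(np_weight)
    return paddle.to_tensor(np_weight, place=paddle.CPUPlace())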
@@ -279,11 +286,12 @@ def _remove_unused_keys(
                None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
            )

-        error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        # error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)

        # force memory release
        del state_dict
-        gc.collect()
+        # gc.collect()

    if len(error_msgs) > 0:
        error_msg = "\n\t".join(error_msgs)
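faster_set_state_dict replaces the generic _load_state_dict_into_model call, and the per-shard gc.collect() is dropped alongside it. A rough, hypothetical sketch of the in-place assignment idea (not the real helper from paddlenlp.transformers.model_utils):

import paddle

def set_state_dict_sketch(model, state_dict, strict_dtype=False):
    error_msgs = []
    params = dict(model.named_parameters())
    for name, value in state_dict.items():
        param = params.get(name)
        if param is None:
            error_msgs.append(f"unexpected key: {name}")
            continue
        if param.dtype != value.dtype:
            if strict_dtype:
                error_msgs.append(f"dtype mismatch for {name}")
                continue
            value = paddle.cast(value, param.dtype)
        param.set_value(value)  # write into the existing parameter storage, no rebuild
    return error_msgs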
@@ -337,6 +345,7 @@ def unified_checkpoint_into_shards(
        tp_actions = model_to_save.get_tensor_parallel_convert_actions(
            model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
        )
+        logger.info("Unified model tensor parallel weights in shards")
        state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

    # build index json file
@@ -490,6 +499,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
    # This should always be a list but, just to be sure.
    if not isinstance(resolved_archive_file, list):
        resolved_archive_file = [resolved_archive_file]
+
    if len(resolved_archive_file) > 1:
        resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

@@ -537,10 +547,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                    tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

                    # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-                    state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
+                    state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
                else:
                    # for pipeline model, we don't need to use tp_actions
-                    state_dict = load_state_dict(shard_file, None, expected_keys)
+                    state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

            returned_state_dict.update(state_dict)
            # force memory release
@@ -553,7 +563,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
        state_dict_master_weight = load_resolved_archive_file(
            resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
        )
-
    # rename optimizer param
    for key in list(state_dict_optim.keys()):
        key_name = key.split("/")
@@ -562,13 +571,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
            key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
        else:
            key_name = "_".join([static_name, key_name[1]])
-        returned_optim_state_dict[key_name] = state_dict_optim[key]
+        returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
        returned_optim_state_dict[key_name].name = key_name

    if has_master_weights:
        for key in list(state_dict_master_weight.keys()):
            static_name = struct2static_name_mappings[key]
-            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
+            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
            returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

    returned_optim_state_dict = nested_copy_place(
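Switching from indexing to .pop(key) lets the source dicts release each tensor as soon as it has been renamed, so only one reference to each optimizer tensor is held at a time. A minimal illustration of the pattern (keys are made up):

src = {"linear_0.w_0/moment1_0": object(), "linear_0.b_0/moment1_0": object()}
renamed = {}
for key in list(src.keys()):
    renamed[key.replace("/", "_")] = src.pop(key)  # src drops its reference immediately
assert not src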
@@ -640,6 +649,7 @@ def unified_optimizer_into_shards(
        tp_actions = model.get_tensor_parallel_convert_actions(
            model.config, model_keys, is_split=False, ignore_error=True
        )
+        logger.info("Unified optimizer tensor parallel in shards")
        optim_state_dict = merge_tensor_parallel_for_optimizer(
            optim_state_dict,
            tp_actions,
@@ -648,6 +658,7 @@ def unified_optimizer_into_shards(
        paddle.device.cuda.empty_cache()

        if master_weights is not None:
+            logger.info("Unified master weight tensor parallel in shards")
            master_weights = merge_tensor_parallel_for_optimizer(
                master_weights,
                tp_actions,
@@ -703,7 +714,6 @@ def unified_optimizer_into_shards(
def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
    index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
    index_filename = os.path.join(resume_from_checkpoint, index_filename)
-
    # Find index json file and distribute this file in global group.
    if distributed_isfile(index_filename):
        distributed_file(index_filename)
@@ -1605,7 +1615,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
    tp_group = hcg.get_model_parallel_group()
    pp_group = hcg.get_pipe_parallel_group()

-    logger.info("Unified checkpoint generating sharded_index json files.")
+    logger.info(
+        f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
+    )

    if tp_group.nranks > 1:
        dist.all_gather_object(index_file_list, index_file, tp_group)
@@ -1714,8 +1726,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):


def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified checkpoint merge tensor parallel in shards")
-
    hcg = fleet.get_hybrid_communicate_group()
    tp_group = hcg.get_model_parallel_group()
    tp_rank = tp_group.rank
@@ -1741,7 +1751,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                action = tp_actions.pop(key)
                tensor = action(ret) if is_dst else None
            else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

            if is_dst:
                state_dict_to_save[key] = tensor
@@ -1754,8 +1764,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):


def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified optimizer tensor parallel in shards")
-
+    # Core function for UC
    hcg = fleet.get_hybrid_communicate_group()
    tp_group = hcg.get_model_parallel_group()
    tp_rank = tp_group.rank
@@ -1773,15 +1782,13 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
            if model_key in tp_actions:
                # for example: beta1, beta2
                if tensor.numel().item() == 1:
-                    tensor = (
-                        tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
-                    )  # Need broadcast when loaded
+                    tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                else:
                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                    action = tp_actions[model_key]
                    tensor = action(ret) if is_dst else None
            else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

            if is_dst:
                state_dict_to_save[filter_keys[i]] = tensor
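For context on the merge path touched above: on the destination rank, distributed_gather returns one shard per tensor-parallel rank and the tp_actions callable stitches them back into a full tensor before the copy to DEST_PLACE. A toy numpy illustration of such a merge for a column-parallel weight (purely illustrative, not PaddleNLP code):

import numpy as np

# stand-in for what the gather hands the destination rank: one shard per TP rank
shards = [np.full((4, 2), rank, dtype=np.float32) for rank in range(2)]
merged = np.concatenate(shards, axis=-1)  # the merge "action" for a column-parallel weight
assert merged.shape == (4, 4)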