     get_expected_state_dict,
     get_optimizer_shard_files,
     mapping_optimizer_tp_actions,
+    update_master_weight_status,
 )
 
 __all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]
 
 
 def merge_splited_param(
-    state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
+    state_dict,
+    partial_tensor_list,
+    param_shape_info,
+    send_table,
+    recv_table,
+    is_master_weights=False,
+    ckpt_quant_stage="O0",
 ):
     """Merge the split params within the sharding group."""
     global_rank = dist.get_rank()
     for key in list(state_dict.keys()):
-        if state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+        if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
             continue
 
         static_name = key if is_master_weights else generate_base_static_name(key)[0]
@@ -89,10 +96,21 @@ def merge_splited_param(
                         )
                         dist.stream.send(tensor, dst=recv_rank)
             state_dict.pop(key)
+
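+    # When checkpoint quantization is on, scalar entries (numel == 1) tied to a
+    # partially split tensor exist on every sharding rank; keep only the copy on
+    # the rank that received the merged tensor and drop the duplicates elsewhere.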
+    if ckpt_quant_stage != "O0":
+        for key in list(state_dict.keys()):
+            if int(state_dict[key].numel()) == 1:  # for example: beta1, beta2
+                static_name = key if is_master_weights else generate_base_static_name(key)[0]
+                if static_name in partial_tensor_list:
+                    recv_rank = recv_table[static_name]
+                    send_info = send_table[static_name]
+                    if global_rank != recv_rank:
+                        state_dict.pop(key)
+
     return state_dict
 
 
-def gather_splited_param_for_optimizer(optimizer):
+def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
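+    # ckpt_quant_stage ("O0" means quantization is disabled) is forwarded to
+    # merge_splited_param so quantization scalars are deduplicated across ranks.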
     hcg = fleet.get_hybrid_communicate_group()
     sharding_group = hcg.get_sharding_parallel_group()
     global_rank = dist.get_rank()
@@ -127,7 +145,7 @@ def gather_splited_param_for_optimizer(optimizer):
     for key in list(optim_state_dict.keys()):
         static_name, _ = generate_base_static_name(key)
         if static_name in param_slice_info.keys():
-            if optim_state_dict[key].numel().item() == 1:  # for example: beta1, beta2
+            if int(optim_state_dict[key].numel()) == 1:  # for example: beta1, beta2
                 continue
             begin, end = param_slice_info[static_name]
             shape, numel, _, _ = param_shape_info[static_name]
@@ -149,13 +167,15 @@ def gather_splited_param_for_optimizer(optimizer):
         recv_table[key] = sharding_ranklist[0][0]  # which sharding rank receives the split tensor
         send_table[key] = [(rank, begin, end) for rank, begin, end in sharding_ranklist]
 
-    merge_splited_param(optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False)
+    merge_splited_param(
+        optim_state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, False, ckpt_quant_stage
+    )
     if master_weights is not None:
         merge_splited_param(master_weights, partial_tensor_list, param_shape_info, send_table, recv_table, True)
     return optim_state_dict, master_weights
 
 
-def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
+def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
     index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -208,6 +228,10 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")
 
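+    # Re-evaluate whether master weights are actually present in the checkpoint
+    # (and which index file describes them) before deciding what to load.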
+    has_master_weights, index_filename_master_weights = update_master_weight_status(
+        args, optimizer, has_master_weights, safe_serialization=True
+    )
+
     if has_master_weights:
         returned_optim_state_dict["master_weights"] = {}
         resolved_archive_file_mw, sharded_metadata_mw = get_optimizer_shard_files(
@@ -217,7 +241,9 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
         if len(resolved_archive_file_mw) > 1:
             resolved_archive_file_mw = tqdm(resolved_archive_file_mw, desc="Loading master weights shards")
 
-    def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False):
+    def load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys, is_master_weights=False, ckpt_quant_stage="O0"
+    ):
         returned_state_dict = {}
 
         if model.config.tensor_parallel_degree > 1:
@@ -232,24 +258,38 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
             if expected_keys.isdisjoint(sharded_metadata["file_map"][os.path.split(shard_file)[-1]]):
                 continue
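+            # Pass ckpt_quant_stage through so load_state_dict can restore
+            # (dequantize) shards that were written with checkpoint quantization.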
             if model.config.tensor_parallel_degree > 1:
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    tp_actions,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
             else:
-                state_dict = load_state_dict(shard_file, None, expected_keys, device="cpu")
+                state_dict = load_state_dict(
+                    shard_file,
+                    None,
+                    expected_keys,
+                    device="cpu",
+                    ckpt_quant_stage=ckpt_quant_stage,
+                )
             returned_state_dict.update(state_dict)
             del state_dict
             gc.collect()
 
         return returned_state_dict
 
     # load optimizer params (tp slices are merged if tensor parallel is enabled)
-    state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys_optim)
+    state_dict_optim = load_resolved_archive_file(
+        resolved_archive_file, sharded_metadata, expected_keys_optim, ckpt_quant_stage=ckpt_quant_stage
+    )
 
     # split each param for its own sharding rank; may need to deal with a potential OOM issue here.
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
         static_name = struct2static_name_mappings.get(key_name[0], None)
 
-        if state_dict_optim[key].numel().item() > 1:
+        if int(state_dict_optim[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
             shape, numel, index, padded_size = param_shape_info[static_name]
             state_dict_optim[key] = state_dict_optim[key].reshape([-1])
@@ -284,7 +324,7 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
 
     for key in list(state_dict_master_weight.keys()):
         static_name = struct2static_name_mappings.get(key, None)
-        if state_dict_master_weight[key].numel().item() > 1:
+        if int(state_dict_master_weight[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
             shape, numel, index, padded_size = param_shape_info[static_name]
             state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
@@ -303,6 +343,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                 paddle.framework._current_expected_place(), False
             )
         returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
+
+        # cast master weights back to fp32 (only needed under remove_master_weight)
+        if returned_optim_state_dict["master_weights"][static_name].dtype != paddle.float32:
+            returned_optim_state_dict["master_weights"][static_name] = paddle.cast(
+                returned_optim_state_dict["master_weights"][static_name], dtype=paddle.float32
+            )
+
         returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
     return returned_optim_state_dict