@@ -238,156 +238,129 @@ def process_one_shard(shard_dir):
     if args.test:
         ref_state_dict = load_file(os.path.join(args.test_hf_dir, 'model.safetensors'))

-    def handle_qkv_proj(key, config, tensor, state_dict):
-        nonlocal tp_size
-
-        hidden_size_per_head = config.hidden_size // config.num_attention_heads
-
-        if config.num_key_value_heads >= tp_size:
-            q_size_tp = config.hidden_size // tp_size
-            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
-            total_size = q_size_tp + 2 * kv_size_tp
-            q_part = tensor[:q_size_tp]
-            k_part = tensor[q_size_tp:q_size_tp + kv_size_tp]
-            v_part = tensor[q_size_tp + kv_size_tp:total_size]
-        else:
-            q_size_tp = config.hidden_size // tp_size
-            kv_size_tp = hidden_size_per_head
-            total_size = q_size_tp + 2 * kv_size_tp
-            q_part = tensor[:q_size_tp]
-            k_part = tensor[q_size_tp:q_size_tp + kv_size_tp]
-            v_part = tensor[q_size_tp + kv_size_tp:total_size]
-
-        preffix = '.'.join(key.split('.')[:4])
-        suffix = '.'.join(key.split('.')[5:])
-        if state_dict.get(f'{preffix}.q_proj.{suffix}') is None:
-            state_dict[f'{preffix}.q_proj.{suffix}'] = q_part
-        else:
-            state_dict[f'{preffix}.q_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.q_proj.{suffix}'], q_part], dim=0)
-        if state_dict.get(f'{preffix}.k_proj.{suffix}') is None:
-            state_dict[f'{preffix}.k_proj.{suffix}'] = k_part
-        else:
-            state_dict[f'{preffix}.k_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.k_proj.{suffix}'], k_part], dim=0)
-        if state_dict.get(f'{preffix}.v_proj.{suffix}') is None:
-            state_dict[f'{preffix}.v_proj.{suffix}'] = v_part
+    def merge_across_tp(key, tp_data):
+        if "linear_fc1.weight" in key:
+            # if the tensor is gate and proj
+            gate_lst = []
+            up_lst = []
+            for infer_param in tp_data:
+                gate, up = infer_param.chunk(2)
+                gate_lst.append(gate)
+                up_lst.append(up)
+            gate = torch.cat(gate_lst, dim=0)
+            up = torch.cat(up_lst, dim=0)
+            tp_data = [gate, up]
+        elif "self_attention.linear_qkv." in key and 'layer_norm' not in key:
+            # if the tensor is qkv, for each param on tp, split into q, k, v
+            # concat q, k, v separately.
+            q_lst = []
+            k_lst = []
+            v_lst = []
+            assert config.num_attention_heads % config.num_key_value_heads == 0
+            num_q_per_kv = config.num_attention_heads // config.num_key_value_heads
+            assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0
+            kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2)
+            split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
+            for infer_param in tp_data:
+                num_query_groups_per_partition = config.num_key_value_heads // tp_size
+                for chunk in infer_param.chunk(num_query_groups_per_partition):
+                    split_size = [
+                        kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
+                        kv_size_per_tp // num_query_groups_per_partition,
+                        kv_size_per_tp // num_query_groups_per_partition
+                    ]
+                    q, k, v = chunk.split(split_size)
+                    q_lst.append(q)
+                    k_lst.append(k)
+                    v_lst.append(v)
+            q = torch.cat(q_lst, dim=0)
+            k = torch.cat(k_lst, dim=0)
+            v = torch.cat(v_lst, dim=0)
+
+            tp_data = [q, k, v]
+
+        elif "layer_norm" in key or "layernorm" in key or "output_layer" in key and args.is_value_model:
+            tp_data = tp_data[0]
         else:
-            state_dict[f'{preffix}.v_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.v_proj.{suffix}'], v_part], dim=0)
+            dim = 0
+            if "linear_fc2.weight" in key or "self_attention.linear_proj" in key:
+                dim = 1
+            tp_data = torch.cat(tp_data, dim=dim)

-        return state_dict

-    def handle_gate_up_proj(key, config, tensor, state_dict):
-        nonlocal tp_size
-
-        intermediate_size_tp = config.intermediate_size // tp_size
-        gate_weight_tp = tensor[:intermediate_size_tp]
-        up_weight_tp = tensor[intermediate_size_tp:]
-        preffix = '.'.join(key.split('.')[:4])
-        suffix = '.'.join(key.split('.')[5:])
-        if state_dict.get(f'{preffix}.gate_proj.{suffix}') is None:
-            state_dict[f'{preffix}.gate_proj.{suffix}'] = gate_weight_tp
-        else:
-            state_dict[f'{preffix}.gate_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.gate_proj.{suffix}'], gate_weight_tp], dim=0)
-        if state_dict.get(f'{preffix}.up_proj.{suffix}') is None:
-            state_dict[f'{preffix}.up_proj.{suffix}'] = up_weight_tp
-        else:
-            state_dict[f'{preffix}.up_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.up_proj.{suffix}'], up_weight_tp], dim=0)
-
-        return state_dict
-
-    def merge_between_tp_rank(key, model_state_dict):
-        nonlocal state_dict
-
-        try:
-            tensor = model_state_dict.pop(key)
-        except:
-            raise RuntimeError(f"error pop: {key}")
-        # Embedding layer
-        if "model.embed_tokens.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            else:
-                state_dict[key] = torch.concat([state_dict[key], tensor], dim=0)
-            return state_dict
-        # Tranformer Layers
-        if "input_layernorm.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            return state_dict
-        if re.search(r"self_attn\.qkv_proj", key):
-            state_dict = handle_qkv_proj(key, config, tensor, state_dict)
-            return state_dict
-        if "self_attn.o_proj.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            else:
-                state_dict[key] = torch.concat([state_dict[key], tensor], dim=1)
-            return state_dict
-        if "post_attention_layernorm.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            return state_dict
-        if re.search(r"mlp\.gate_up_proj\.weight", key):
-            state_dict = handle_gate_up_proj(key, config, tensor, state_dict)
-            return state_dict
-        if "mlp.down_proj.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            else:
-                state_dict[key] = torch.concat([state_dict[key], tensor], dim=1)
-            return state_dict
-        # Final LayerNorm
-        if "model.norm.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            return state_dict
-        if not args.tie_word_embedding:
-            if args.is_value_model:
-                if "lm_head.weight" in key:
-                    if state_dict[key] is None:
-                        state_dict[key] = tensor
-                if "reward_head.weight" in key:
-                    if state_dict[key] is None:
-                        state_dict[key] = tensor
-            else:
-                if "lm_head.weight" in key:
-                    if state_dict[key] is None:
-                        state_dict[key] = tensor
-                    else:
-                        state_dict[key] = torch.concat([state_dict[key], tensor], dim=0)
-            return state_dict
-        return state_dict
+        return tp_data

-    for pp_rank in range(pp_size):
-        print(f'pp_rank: {pp_rank}')
-        for vpp_rank, state_dict_single_layer in enumerate(model_state_dict_lst[pp_rank][0]):
-            state_dict_single_layer_iter = state_dict_single_layer.copy()
-            keys = state_dict_single_layer_iter.keys()
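+    # walk shards in (vpp_rank, pp_rank) order, shifting each shard's local
+    # decoder layer index by layers_cum so merged keys carry global layer numbers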
+    vpp_size = len(model_state_dict_lst[0][0])
+    layers_cum = 0
+    for vpp_rank in range(vpp_size):
+        for pp_rank in range(pp_size):
+            layers_handled = 0
+            keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys()
             for key in keys:
                 if "extra_state" in key:
                     continue
-                if args.tie_word_embedding and ("lm_head" in key or "reward_head" in key):
+                if args.tie_word_embedding and ("output_layer" in key):
                     print(f'skip lm_head and reward_head loading because of tie_word_embeddings')
                     continue
-                if re.search(r"self_attn\.qkv_proj", key) is None and re.search(r"gate_up_proj", key) is None:
-                    state_dict[key] = None
-                for tp_rank in range(tp_size):
-                    model_state_dict = model_state_dict_lst[pp_rank][tp_rank][vpp_rank]
-                    state_dict = merge_between_tp_rank(key, model_state_dict)
-
+                new_key = key
+                if "decoder.layers." in key:
+                    local_layer_no = int(key.split('.')[2])
+                    layers_handled = max(local_layer_no, layers_handled)
+                    global_layer_no = local_layer_no + layers_cum
+                    new_key_list = key.split('.')
+                    new_key_list[2] = str(global_layer_no)
+                    new_key = '.'.join(new_key_list)
+
+                tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)]
+                merged = merge_across_tp(new_key, tp_data)
+                if not isinstance(merged, list):
+                    state_dict[new_key] = merged
+                elif len(merged) == 3:
+                    # split qkv
+                    for n, d in zip(['q', 'k', 'v'], merged):
+                        state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d
+                elif len(merged) == 2:
+                    # split gate up
+                    state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0]
+                    state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1]
+            layers_cum += layers_handled + 1  # zero based
+
     del model_state_dict_lst
+
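+    # used by _replace_name in the test path below to translate merged Megatron
+    # keys into the reference (HF) checkpoint's naming scheme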
+    params_mapping = [
+        # (megatron core gpt model name, vllm model name)
+        ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
+        ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
+        ("embedding.word_embeddings", "model.embed_tokens"),
+        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
+        ("self_attention.linear_proj", "self_attn.o_proj"),
+        ("pre_mlp_layernorm", "post_attention_layernorm"),
+        ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
+        ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
+        ("mlp.linear_fc1", "mlp.gate_up_proj"),
+        ("mlp.linear_fc2", "mlp.down_proj"),
+        ("decoder.final_layernorm", "model.norm"),
+        ("output_layer", "lm_head"),
+        ("self_attention.linear_q", "self_attn.q_proj"),
+        ("self_attention.linear_k", "self_attn.k_proj"),
+        ("self_attention.linear_v", "self_attn.v_proj"),
+    ]
+
     if args.test:
-        for key, value in state_dict.items():
-            print(key)
-            if key not in ref_state_dict:
-                raise RuntimeError(f'key: {key} not exist in ref_state_dict {value}')
-            if value.shape != ref_state_dict[key].shape:
-                raise RuntimeError(f'key: {key} shape mismatch {value.shape}, {ref_state_dict[key].shape}')
-            assert value.dtype == ref_state_dict[key].dtype, f'{key} state_dict[key].dtype: {value.dtype} != ref_state_dict[key].dtype: {ref_state_dict[key].dtype}'
-            torch.testing.assert_close(value, ref_state_dict[key], atol=1e-4, rtol=1e-4)
-        for key in ref_state_dict:
-            if key not in state_dict:
-                raise RuntimeError(f'key: {key} not exist in state_dict {ref_state_dict[key]}')
-
+
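+        # verify every merged tensor against the reference HF checkpoint,
+        # skipping untranslated names, absent biases, rotary buffers and tied lm_head weights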
+        for original_name, loaded_weight in state_dict.items():
+            name = _replace_name(original_name, params_mapping)
+            if not name or name.endswith(".bias") and name not in ref_state_dict:
+                continue
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if args.tie_word_embedding and "lm_head.weight" in name:
+                continue
+            if name not in ref_state_dict:
+                raise RuntimeError(f'key: {name} not exist in state_dict')
+            param = ref_state_dict[name]
+            assert loaded_weight.dtype == param.dtype
+            torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4)

     print('Writing to local disk')
     if args.target_dir is None:
@@ -415,6 +388,29 @@ def merge_between_tp_rank(key, model_state_dict):
     if args.hf_upload_path:
         upload_model_to_huggingface(hf_path)

+
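+# translate a Megatron-Core parameter name to its HF counterpart using the
+# (megatron, hf) pairs in name_mapping; decoder-layer names also have their
+# "decoder" prefix rewritten to "model"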
+def _replace_name(megatron_name, name_mapping):
+    for m_name, v_name in name_mapping:
+        if m_name not in megatron_name:
+            continue
+        if "layers" in megatron_name:  # deal with decoder layers
+            megatron_name = megatron_name.replace("decoder", "model")
+            megatron_name_list = megatron_name.split(".")
+            if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
+                param_name_list = megatron_name_list[:3]
+                param_name_list.append(v_name)
+                param_name = ".".join(param_name_list)
+            else:
+                param_name_list = megatron_name_list[:3]
+                weight_or_bias = megatron_name_list[-1]
+                param_name_list.append(v_name)
+                param_name_list.append(weight_or_bias)
+                param_name = ".".join(param_name_list)
+            return param_name
+        else:
+            param_name = megatron_name.replace(m_name, v_name)
+            return param_name
+
 if __name__ == '__main__':
     if args.backend == "fsdp":
         convert_fsdp_checkpoints_to_hfmodels()