 import argparse
+import os
 from typing import Dict

 import torch

 from torchtune.training import FullModelHFCheckpointer

+_HF_PHI_4_FROM_META = {
+    "tok_embeddings.weight": "model.embed_tokens.weight",
+    "norm.weight": "model.norm.weight",
+    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
+    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
+    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
+    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
+    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
+    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
+    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
+    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
+    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
+    "output.weight": "lm_head.weight",
+}
+
+
+def phi_4_hf_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from hf's format to Meta's format.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in hf's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _HF_PHI_4_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        if key.endswith("mlp.gate_up_proj.weight"):
+            # Split the gate_up_proj into gate_proj and up_proj
+            hidden_dim = value.shape[0] // 2
+            assert 2 * hidden_dim == value.shape[0]
+            gate = value[0:hidden_dim, :]
+            up = value[hidden_dim:, :]
+            for new_key, new_value in [("gate_proj", gate), ("up_proj", up)]:
+                new_key = key.replace("gate_up_proj", new_key)
+                new_key = get_mapped_key(new_key, inverted_mapping_dict)
+                converted_state_dict[new_key] = new_value
+        elif key.endswith("self_attn.qkv_proj.weight"):
+            # Split the qkv_proj into q_proj, k_proj, and v_proj
+            q_dim = value.shape[1]
+            kv_dim = (value.shape[0] - q_dim) // 2
+            assert 2 * kv_dim + q_dim == value.shape[0]
+            q = value[0:q_dim, :]
+            k = value[q_dim : (q_dim + kv_dim), :]
+            v = value[(q_dim + kv_dim) :, :]
+            for new_key, new_value in [("q_proj", q), ("k_proj", k), ("v_proj", v)]:
+                new_key = key.replace("qkv_proj", new_key)
+                new_key = get_mapped_key(new_key, inverted_mapping_dict)
+                converted_state_dict[new_key] = new_value
+        else:
+            new_key = get_mapped_key(key, inverted_mapping_dict)
+            converted_state_dict[new_key] = value
+    return converted_state_dict
+

 # Standard _FROM_META weight mapping of Meta weights to TorchTune.
 _PHI_4_FROM_META = {
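The `phi_4_hf_to_meta` path added above undoes the weight fusing in the HF Phi-4 checkpoint: `mlp.gate_up_proj` is split row-wise into `gate_proj` and `up_proj`, and `self_attn.qkv_proj` into `q_proj`, `k_proj`, and `v_proj`, before `get_mapped_key` renames each piece to its Meta key. Below is a minimal, self-contained sketch of the qkv row split only; the dimensions are made up, and it relies on the query projection's output width matching the hidden size (as it does for Phi-4), which is why `q_dim` can be taken from `value.shape[1]`.

```python
import torch

# Illustrative sizes only; real Phi-4 dimensions are much larger.
hidden_dim = 8   # query rows == hidden size for Phi-4
kv_rows = 4      # k and v each get fewer rows than q under grouped-query attention

qkv = torch.randn(hidden_dim + 2 * kv_rows, hidden_dim)  # fused [q; k; v] weight

q_dim = qkv.shape[1]                  # same trick as phi_4_hf_to_meta
kv_dim = (qkv.shape[0] - q_dim) // 2  # remaining rows split evenly between k and v
q = qkv[:q_dim, :]
k = qkv[q_dim : q_dim + kv_dim, :]
v = qkv[q_dim + kv_dim :, :]

assert q.shape == (hidden_dim, hidden_dim)
assert k.shape == v.shape == (kv_rows, hidden_dim)
```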
@@ -51,22 +109,29 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
     return converted_state_dict


-def convert_weights(input_dir: str, output_file: str) -> None:
+def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
     # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
-    checkpointer = FullModelHFCheckpointer(
-        checkpoint_dir=input_dir,
-        checkpoint_files=[
-            "model-00001-of-00002.safetensors",
-            "model-00002-of-00002.safetensors",
-        ],
-        output_dir=".",
-        model_type="PHI4",
-    )
+    if os.path.isdir(input_dir_or_checkpoint):
+        checkpointer = FullModelHFCheckpointer(
+            checkpoint_dir=input_dir_or_checkpoint,
+            checkpoint_files=[
+                "model-00001-of-00002.safetensors",
+                "model-00002-of-00002.safetensors",
+            ],
+            output_dir=".",
+            model_type="PHI4",
+        )
+        print("Loading checkpoint from directory...")
+        sd = checkpointer.load_checkpoint()
+        sd = sd["model"]
+        print("Converting checkpoint...")
+        sd = phi_4_tune_to_meta(sd)
+    else:
+        print("Loading checkpoint from file...")
+        sd = torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
+        print("Converting checkpoint...")
+        sd = phi_4_hf_to_meta(sd)

-    print("Loading checkpoint...")
-    sd = checkpointer.load_checkpoint()
-    print("Converting checkpoint...")
-    sd = phi_4_tune_to_meta(sd["model"])
     print("Saving checkpoint...")
     torch.save(sd, output_file)
     print("Done.")
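With this change, `convert_weights` accepts either input shape: a directory of HF safetensors shards, which still goes through torchtune's `FullModelHFCheckpointer` and `phi_4_tune_to_meta`, or a single checkpoint file, which is loaded with `torch.load` and converted with the new `phi_4_hf_to_meta`. A rough usage sketch, with placeholder paths that are not taken from this PR:

```python
# Directory of HF safetensors shards -> FullModelHFCheckpointer + phi_4_tune_to_meta
convert_weights("checkpoints/phi-4-mini-hf", "phi4_meta.pt")

# Single checkpoint file -> torch.load(..., weights_only=True) + phi_4_hf_to_meta
convert_weights("checkpoints/phi-4-mini.pth", "phi4_meta.pt")
```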
@@ -79,7 +144,7 @@ def main():
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Path to directory containing checkpoint files",
+        help="Path to directory containing checkpoint files, or path to a single checkpoint file. ",
     )
     parser.add_argument("output", type=str, help="Path to the output checkpoint")
