@@ -51,37 +51,40 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
5151 return converted_state_dict
5252
5353
54- def main ():
55- parser = argparse .ArgumentParser (
56- description = "Convert Phi-4-mini weights to Meta format."
57- )
58- parser .add_argument (
59- "input_dir" ,
60- type = str ,
61- help = "Path to directory containing checkpoint files" ,
62- )
63- parser .add_argument ("output" , type = str , help = "Path to the output checkpoint" )
64-
65- args = parser .parse_args ()
66-
54+ def convert_weights (input_dir : str , output_file : str ) -> None :
55+ # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
6756 checkpointer = FullModelHFCheckpointer (
68- checkpoint_dir = args . input_dir ,
57+ checkpoint_dir = input_dir ,
6958 checkpoint_files = [
7059 "model-00001-of-00002.safetensors" ,
7160 "model-00002-of-00002.safetensors" ,
7261 ],
7362 output_dir = "." ,
74- model_type = "PHI3_MINI " ,
63+ model_type = "PHI4 " ,
7564 )
7665
7766 print ("Loading checkpoint..." )
7867 sd = checkpointer .load_checkpoint ()
79-
8068 print ("Converting checkpoint..." )
8169 sd = phi_4_tune_to_meta (sd ["model" ])
70+ print ("Saving checkpoint..." )
71+ torch .save (sd , output_file )
72+ print ("Done." )
8273
83- torch .save (sd , args .output )
84- print (f"Checkpoint saved to { args .output } " )
74+
75+ def main ():
76+ parser = argparse .ArgumentParser (
77+ description = "Convert Phi-4-mini weights to Meta format."
78+ )
79+ parser .add_argument (
80+ "input_dir" ,
81+ type = str ,
82+ help = "Path to directory containing checkpoint files" ,
83+ )
84+ parser .add_argument ("output" , type = str , help = "Path to the output checkpoint" )
85+
86+ args = parser .parse_args ()
87+ convert_weights (args .input_dir , args .output )
8588
8689
8790if __name__ == "__main__" :
0 commit comments