 _SMOLLM_FROM_META = {
     "tok_embeddings.weight": "tok_embeddings.weight",
     "norm.weight": "norm.scale",
-    "output.weight": "output.weight",
     "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
     "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
     "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
@@ -41,10 +40,31 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     for key, value in state_dict.items():
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value
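+    # SmolLM ties its input and output embeddings (tie_word_embeddings in the HF
+    # config), so there is no separate output projection; alias the embedding table.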
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
 
     return converted_state_dict
 
 
+def convert_weights(input_dir: str, output_file: str) -> None:
+    # We don't strictly need the TorchTune checkpointer here; we could just
+    # aggregate the checkpoint files ourselves.
+    checkpointer = FullModelHFCheckpointer(
+        checkpoint_dir=input_dir,
+        checkpoint_files=["model.safetensors"],
+        output_dir=".",  # required by the checkpointer API but unused; we save manually below
+        model_type="LLAMA3",
+    )
+
+    print("Loading checkpoint...")
+    sd = checkpointer.load_checkpoint()
+    print("Converting checkpoint...")
+    sd = smollm_tune_to_meta(sd["model"])
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Convert SmolLM weights to Meta format."
@@ -57,23 +77,7 @@ def main():
     parser.add_argument("output", type=str, help="Path to the output checkpoint")
 
     args = parser.parse_args()
-
-    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
-    checkpointer = FullModelHFCheckpointer(
-        checkpoint_dir=args.input_dir,
-        checkpoint_files=["model.safetensors"],
-        output_dir=".",
-        model_type="LLAMA",
-    )
-
-    print("Loading checkpoint...")
-    sd = checkpointer.load_checkpoint()
-
-    print("Converting checkpoint...")
-    sd = smollm_tune_to_meta(sd["model"])
-
-    torch.save(sd, args.output)
-    print(f"Checkpoint saved to {args.output}")
+    convert_weights(args.input_dir, args.output)
 
 
 if __name__ == "__main__":
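
With the conversion logic factored into convert_weights, main() now just parses arguments and delegates. A hypothetical invocation (the script name and paths are placeholders; input_dir is assumed to be a positional argument defined in a part of main() outside this diff):

    python convert_smollm_weights.py /path/to/smollm_hf_dir smollm_meta.pth

The input directory must contain the model.safetensors file named in convert_weights.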