from torchtune.models.convert_weights import get_mapped_key

- _LFM_2_FROM_META = {
-     "tok_embeddings.weight": "model.embed_tokens.weight",
-     "norm.weight": "model.embedding_norm.weight",
+ _LFM_2_TO_META = {
+     "model.embed_tokens.weight": "tok_embeddings.weight",
+     "model.embedding_norm.weight": "norm.weight",

-     "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
-     "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
-     "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
-     "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.out_proj.weight",
-     "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_layernorm.weight",
-     "layers.{}.attention.q_norm_fn.weight": "model.layers.{}.self_attn.q_layernorm.weight",
+     "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+     "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+     "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+     "model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight",
+     "model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight",
+     "model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight",

-     "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
+     "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",

-     "layers.{}.attention_norm.weight": "model.layers.{}.operator_norm.weight",
+     "model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight",
}


- def lfm_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ def lfm_2_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
-     Convert a state dict from torchtune's format to Meta's format. This function
+     Convert a state dict from LFM2 HF format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
-         state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+         state_dict (Dict[str, torch.Tensor]): State dict in LFM2 HF format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
-     inverted_mapping_dict = {v: k for k, v in _LFM_2_FROM_META.items()}

    for key, value in state_dict.items():
        try:
-             new_key = get_mapped_key(key, inverted_mapping_dict)
+             new_key = get_mapped_key(key, _LFM_2_TO_META)
        except:
            new_key = key.removeprefix("model.")

@@ -54,7 +53,7 @@ def lfm_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
        else:
            converted_state_dict[new_key] = value

-     # If lm_head.weight is not present in state dict, assume tied embeddings (e.g., 0.6b and 4b models)
+     # If lm_head.weight is not present in state dict, assume tied embeddings
    if "lm_head.weight" not in state_dict:
        converted_state_dict["output.weight"] = converted_state_dict[
            "tok_embeddings.weight"
@@ -73,7 +72,7 @@ def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
-     sd = lfm_2_tune_to_meta(sd)
+     sd = lfm_2_to_meta(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")
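
For context, a minimal sketch of how the new mapping is consumed (hypothetical snippet, not part of this PR; it assumes it runs in the same module so _LFM_2_TO_META is in scope, and the layer index 3 is arbitrary):

from torchtune.models.convert_weights import get_mapped_key

# get_mapped_key substitutes the layer index found in the key into the "{}"
# placeholder, so a single entry in _LFM_2_TO_META covers every decoder layer.
hf_key = "model.layers.3.self_attn.q_proj.weight"
print(get_mapped_key(hf_key, _LFM_2_TO_META))  # layers.3.attention.wq.weight

# Keys with no entry in the mapping are expected to raise, which is why the
# converter falls back to stripping the "model." prefix in its except branch.
# When the checkpoint also lacks lm_head.weight, output.weight is tied to
# tok_embeddings.weight at the end of lfm_2_to_meta.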