 _QWEN_3_FROM_META = {
     "tok_embeddings.weight": "model.embed_tokens.weight",
     "norm.weight": "model.norm.weight",
+    "output.weight": "lm_head.weight",
     "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
     "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
     "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
@@ -47,20 +48,19 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}

     for key, value in state_dict.items():
-        # Tied embeddings for 0.6b and 4b models.
-        if key == "lm_head.weight":
-            continue
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value

-    converted_state_dict["output.weight"] = converted_state_dict[
-        "tok_embeddings.weight"
-    ]
+    # If lm_head.weight is not present in the state dict, assume tied embeddings (e.g., the 0.6b and 4b models).
+    if "lm_head.weight" not in state_dict:
+        converted_state_dict["output.weight"] = converted_state_dict[
+            "tok_embeddings.weight"
+        ]

     return converted_state_dict


-def load_checkpoint(input_dir: str) -> Dict:
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
     index_path = os.path.join(input_dir, "model.safetensors.index.json")
     if os.path.exists(index_path):
         # Sharded checkpoint.
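
The tied-embedding handling above now keys off the presence of "lm_head.weight" instead of unconditionally copying the token embedding, so untied checkpoints keep their own output head. A toy demonstration of that fallback, with a hypothetical helper name and tensor shapes chosen only for illustration:

    import torch

    def tie_output_if_needed(source_sd: dict, converted: dict) -> dict:
        # Tied checkpoint (e.g. 0.6b/4b): no lm_head.weight in the source,
        # so reuse the token embedding as the output projection.
        if "lm_head.weight" not in source_sd:
            converted["output.weight"] = converted["tok_embeddings.weight"]
        return converted

    tied_src = {"model.embed_tokens.weight": torch.zeros(4, 2)}
    conv = {"tok_embeddings.weight": tied_src["model.embed_tokens.weight"]}
    out = tie_output_if_needed(tied_src, conv)
    assert out["output.weight"] is out["tok_embeddings.weight"]
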
@@ -86,6 +86,15 @@ def load_checkpoint(input_dir: str) -> Dict:
     return state_dict


+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
 def convert_weights(input_dir: str, output_file: str) -> None:
     print("Loading checkpoint...")
     sd = load_checkpoint(input_dir)
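
The new `load_checkpoint` wrapper prefers a single `pytorch_model.bin` in the input directory and otherwise falls back to the safetensors loader. A quick sanity check of that dispatch, assuming a scratch directory (the file name is the one the code looks for; the tensor contents are illustrative):

    import os
    import tempfile

    import torch

    with tempfile.TemporaryDirectory() as d:
        torch.save({"model.norm.weight": torch.ones(8)}, os.path.join(d, "pytorch_model.bin"))
        sd = load_checkpoint(d)  # prints "Loading checkpoint from PyTorch .bin file"
        assert "model.norm.weight" in sd
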
@@ -103,7 +112,7 @@ def main():
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Path to directory containing checkpoint files",
+        help="Path to a directory containing either safetensors checkpoint files or a pytorch_model.bin checkpoint file.",
     )
     parser.add_argument("output", type=str, help="Path to the output checkpoint")

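Assuming the module is run directly as a script (file name hypothetical), the updated help text corresponds to invocations like:

    python convert_qwen3_weights.py /path/to/qwen3_hf_checkpoint qwen3_meta.pth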