diff --git a/examples/models/qwen3/convert_weights.py b/examples/models/qwen3/convert_weights.py
index 53a609885d7..6d5254906fb 100644
--- a/examples/models/qwen3/convert_weights.py
+++ b/examples/models/qwen3/convert_weights.py
@@ -13,6 +13,7 @@
 _QWEN_3_FROM_META = {
     "tok_embeddings.weight": "model.embed_tokens.weight",
     "norm.weight": "model.norm.weight",
+    "output.weight": "lm_head.weight",
     "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
     "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
     "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
@@ -47,20 +48,19 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
     inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}
 
     for key, value in state_dict.items():
-        # Tied embeddings for 0.6b and 4b models.
-        if key == "lm_head.weight":
-            continue
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value
 
-    converted_state_dict["output.weight"] = converted_state_dict[
-        "tok_embeddings.weight"
-    ]
+    # If lm_head.weight is not present in state dict, assume tied embeddings (e.g., 0.6b and 4b models)
+    if "lm_head.weight" not in state_dict:
+        converted_state_dict["output.weight"] = converted_state_dict[
+            "tok_embeddings.weight"
+        ]
 
     return converted_state_dict
 
 
-def load_checkpoint(input_dir: str) -> Dict:
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
     index_path = os.path.join(input_dir, "model.safetensors.index.json")
     if os.path.exists(index_path):
         # Sharded checkpoint.
@@ -86,6 +86,15 @@ def load_checkpoint(input_dir: str) -> Dict:
     return state_dict
 
 
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
 def convert_weights(input_dir: str, output_file: str) -> None:
     print("Loading checkpoint...")
     sd = load_checkpoint(input_dir)
@@ -103,7 +112,7 @@ def main():
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Path to directory containing checkpoint files",
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
     )
     parser.add_argument("output", type=str, help="Path to the output checkpoint")
 
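For reviewers, here is a minimal standalone sketch of the tie-handling behavior this patch changes: `output.weight` is taken from `lm_head.weight` when the checkpoint ships one, and falls back to sharing the embedding matrix when it does not. The helper name `tune_to_meta_sketch`, the toy tensors, and the two-key mapping are illustrative assumptions, not part of the patch.

```python
import torch


# Hypothetical helper (not in the patch) mirroring the patched logic of
# qwen_3_tune_to_meta, with the key mapping reduced to the two relevant entries.
def tune_to_meta_sketch(state_dict):
    converted = {"tok_embeddings.weight": state_dict["model.embed_tokens.weight"]}
    if "lm_head.weight" in state_dict:
        # Untied checkpoint: the new "output.weight" -> "lm_head.weight" mapping applies.
        converted["output.weight"] = state_dict["lm_head.weight"]
    else:
        # Tied checkpoint (e.g. 0.6b / 4b): reuse the embedding matrix, as the patch
        # does when "lm_head.weight" is absent from the state dict.
        converted["output.weight"] = converted["tok_embeddings.weight"]
    return converted


# Untied case: output.weight comes from lm_head.weight.
untied = {
    "model.embed_tokens.weight": torch.zeros(4, 8),
    "lm_head.weight": torch.ones(4, 8),
}
assert torch.equal(tune_to_meta_sketch(untied)["output.weight"], untied["lm_head.weight"])

# Tied case: output.weight aliases the embedding table.
tied = {"model.embed_tokens.weight": torch.zeros(4, 8)}
out = tune_to_meta_sketch(tied)
assert out["output.weight"] is out["tok_embeddings.weight"]
```

The new `load_checkpoint` wrapper follows the same spirit on the loading side: it prefers `pytorch_model.bin` when present and otherwise defers to the existing safetensors path, which is why the `input_dir` help text now mentions both formats.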