From 751eb3f489e64d2a0ed60b89d39214f8bb29d743 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Mon, 12 May 2025 20:34:23 -0700
Subject: [PATCH 1/3] init

---
 examples/models/qwen3/convert_weights.py | 31 +++++++++++++++---------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/examples/models/qwen3/convert_weights.py b/examples/models/qwen3/convert_weights.py
index 53a609885d7..c00b6a32ede 100644
--- a/examples/models/qwen3/convert_weights.py
+++ b/examples/models/qwen3/convert_weights.py
@@ -13,6 +13,7 @@
 _QWEN_3_FROM_META = {
     "tok_embeddings.weight": "model.embed_tokens.weight",
     "norm.weight": "model.norm.weight",
+    "output.weight": "lm_head.weight",
     "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
     "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
     "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
@@ -47,20 +48,19 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}

     for key, value in state_dict.items():
-        # Tied embeddings for 0.6b and 4b models.
-        if key == "lm_head.weight":
-            continue
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value

-    converted_state_dict["output.weight"] = converted_state_dict[
-        "tok_embeddings.weight"
-    ]
+    # If lm_head.weight is not present, assume tied embeddings (e.g., 0.6b and 4b models)
+    if "lm_head.weight" not in state_dict:
+        converted_state_dict["output.weight"] = converted_state_dict[
+            "tok_embeddings.weight"
+        ]

     return converted_state_dict


-def load_checkpoint(input_dir: str) -> Dict:
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
     index_path = os.path.join(input_dir, "model.safetensors.index.json")
     if os.path.exists(index_path):
         # Sharded checkpoint.
@@ -86,9 +86,18 @@ def load_checkpoint(input_dir: str) -> Dict:
     return state_dict


-def convert_weights(input_dir: str, output_file: str) -> None:
+def load_checkpoint(input_dir_or_checkpoint: str) -> Dict:
+    if os.path.isdir(input_dir_or_checkpoint):
+        return load_checkpoint_from_safetensors(input_dir_or_checkpoint)
+    else:
+        return torch.load(
+            input_dir_or_checkpoint, map_location="cpu", weights_only=True
+        )
+
+
+def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
     print("Loading checkpoint...")
-    sd = load_checkpoint(input_dir)
+    sd = load_checkpoint(input_dir_or_checkpoint)
     print("Converting checkpoint...")
     sd = qwen_3_tune_to_meta(sd)
     print("Saving checkpoint...")
@@ -101,9 +110,9 @@ def main():
         description="Convert Qwen3 weights to Meta format."
     )
     parser.add_argument(
-        "input_dir",
+        "input_dir_or_checkpoint",
         type=str,
-        help="Path to directory containing checkpoint files",
+        help="Path to directory containing safetensor checkpoint files, or path to a PyTorch checkpoint file.",
     )
     parser.add_argument("output", type=str, help="Path to the output checkpoint")

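Note: patch 1 changes the tied-embeddings handling from unconditional to conditional; output.weight is copied from tok_embeddings.weight only when the checkpoint has no lm_head.weight. A minimal sketch of that fallback, using an illustrative stand-in mapping so it runs without torchtune's get_mapped_key helper (the names hf_to_meta and to_meta are not from the patch):

    import torch

    # Stand-in for a few _QWEN_3_FROM_META entries, inverted to the
    # HF -> Meta direction (illustrative only).
    hf_to_meta = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }

    def to_meta(state_dict):
        converted = {hf_to_meta[k]: v for k, v in state_dict.items()}
        # Patch 1's change: synthesize output.weight only when the
        # checkpoint lacks lm_head.weight (tied-embedding models,
        # e.g. the 0.6b and 4b variants).
        if "lm_head.weight" not in state_dict:
            converted["output.weight"] = converted["tok_embeddings.weight"]
        return converted

    # A tied-embedding checkpoint ships without lm_head.weight:
    tied = to_meta({
        "model.embed_tokens.weight": torch.randn(8, 4),
        "model.norm.weight": torch.ones(4),
    })
    assert tied["output.weight"] is tied["tok_embeddings.weight"]

    # An untied checkpoint keeps its own lm_head.weight intact:
    untied = to_meta({
        "model.embed_tokens.weight": torch.randn(8, 4),
        "model.norm.weight": torch.ones(4),
        "lm_head.weight": torch.randn(8, 4),
    })
    assert untied["output.weight"] is not untied["tok_embeddings.weight"]

Under the old code, an untied checkpoint's lm_head.weight was skipped and output.weight was always overwritten with the embedding weights; the conditional preserves a real lm_head when one exists.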
From ea7cbdc6d5a2383921747999bac7dee7935d224b Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Tue, 13 May 2025 08:05:19 -0700
Subject: [PATCH 2/3] up

---
 examples/models/qwen3/convert_weights.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/models/qwen3/convert_weights.py b/examples/models/qwen3/convert_weights.py
index c00b6a32ede..ee524b58df2 100644
--- a/examples/models/qwen3/convert_weights.py
+++ b/examples/models/qwen3/convert_weights.py
@@ -117,7 +117,7 @@ def main():
     parser.add_argument("output", type=str, help="Path to the output checkpoint")

     args = parser.parse_args()
-    convert_weights(args.input_dir, args.output)
+    convert_weights(args.input_dir_or_checkpoint, args.output)


 if __name__ == "__main__":
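Note: patch 2 is the follow-up to patch 1's rename of the positional argument; argparse derives the Namespace attribute from the argument name, so the call site must read args.input_dir_or_checkpoint or it raises AttributeError. A quick illustration (the path is hypothetical):

    import argparse

    parser = argparse.ArgumentParser()
    # The positional name becomes the attribute on the parsed Namespace.
    parser.add_argument("input_dir_or_checkpoint", type=str)
    args = parser.parse_args(["/tmp/qwen3-0.6b"])  # hypothetical path
    print(args.input_dir_or_checkpoint)  # -> /tmp/qwen3-0.6b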
From 880c99b01db5b2c8255233b761cc28d030ed1ecd Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Tue, 13 May 2025 19:23:52 -0700
Subject: [PATCH 3/3] up

---
 examples/models/qwen3/convert_weights.py | 26 ++++++++++++------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/examples/models/qwen3/convert_weights.py b/examples/models/qwen3/convert_weights.py
index ee524b58df2..6d5254906fb 100644
--- a/examples/models/qwen3/convert_weights.py
+++ b/examples/models/qwen3/convert_weights.py
@@ -51,7 +51,7 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value

-    # If lm_head.weight is not present, assume tied embeddings (e.g., 0.6b and 4b models)
+    # If lm_head.weight is not present in state dict, assume tied embeddings (e.g., 0.6b and 4b models)
     if "lm_head.weight" not in state_dict:
         converted_state_dict["output.weight"] = converted_state_dict[
             "tok_embeddings.weight"
         ]
@@ -86,18 +86,18 @@ def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
     return state_dict


-def load_checkpoint(input_dir_or_checkpoint: str) -> Dict:
-    if os.path.isdir(input_dir_or_checkpoint):
-        return load_checkpoint_from_safetensors(input_dir_or_checkpoint)
-    else:
-        return torch.load(
-            input_dir_or_checkpoint, map_location="cpu", weights_only=True
-        )
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)


-def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
+def convert_weights(input_dir: str, output_file: str) -> None:
     print("Loading checkpoint...")
-    sd = load_checkpoint(input_dir_or_checkpoint)
+    sd = load_checkpoint(input_dir)
     print("Converting checkpoint...")
     sd = qwen_3_tune_to_meta(sd)
     print("Saving checkpoint...")
@@ -110,14 +110,14 @@ def main():
         description="Convert Qwen3 weights to Meta format."
     )
     parser.add_argument(
-        "input_dir_or_checkpoint",
+        "input_dir",
         type=str,
-        help="Path to directory containing safetensor checkpoint files, or path to a PyTorch checkpoint file.",
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
     )
     parser.add_argument("output", type=str, help="Path to the output checkpoint")

     args = parser.parse_args()
-    convert_weights(args.input_dir_or_checkpoint, args.output)
+    convert_weights(args.input_dir, args.output)


 if __name__ == "__main__":
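Note: patch 3 reworks the dispatch so the tool always takes a directory and probes it for pytorch_model.bin before falling back to safetensors, which is why the input_dir_or_checkpoint rename from patches 1 and 2 is reverted. A self-contained sketch of that probe order (the safetensors branch is elided here; the temporary directory stands in for a real checkpoint directory):

    import os
    import tempfile

    import torch

    def load_checkpoint(input_dir: str):
        # Mirrors patch 3's probe order: an eager .bin checkpoint wins,
        # otherwise the directory is treated as a safetensors checkpoint.
        pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
        if os.path.exists(pytorch_path):
            return torch.load(pytorch_path, map_location="cpu", weights_only=True)
        raise NotImplementedError("safetensors fallback elided in this sketch")

    with tempfile.TemporaryDirectory() as d:
        torch.save({"model.norm.weight": torch.ones(4)},
                   os.path.join(d, "pytorch_model.bin"))
        sd = load_checkpoint(d)
        assert "model.norm.weight" in sd

Probing the directory keeps the CLI at a single positional argument while supporting both checkpoint formats, rather than making the caller distinguish a file path from a directory path.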