25 changes: 17 additions & 8 deletions examples/models/qwen3/convert_weights.py
@@ -13,6 +13,7 @@
_QWEN_3_FROM_META = {
"tok_embeddings.weight": "model.embed_tokens.weight",
"norm.weight": "model.norm.weight",
"output.weight": "lm_head.weight",
"layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
"layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
"layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
@@ -47,20 +48,19 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}

for key, value in state_dict.items():
# Tied embeddings for 0.6b and 4b models.
if key == "lm_head.weight":
continue
new_key = get_mapped_key(key, inverted_mapping_dict)
converted_state_dict[new_key] = value

converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]
# If lm_head.weight is not present in state dict, assume tied embeddings (e.g., 0.6b and 4b models)
if "lm_head.weight" not in state_dict:
Contributor:
lm_head is present in the HF checkpoints even if the embeddings are tied; it will just be the same weights as tok_embeddings.

Contributor Author:
Hmmm, are you sure it's there? I thought that when config.tie_word_embeddings = True it might not be there, but gets materialized by a tie_weights() call on the HF model.

In any case, if it is there, it's covered by the regular loop through the keys and this logic is not executed. If it's not there, this sets lm_head's weight to the embeddings.

Contributor:
Yeah, it's here: https://huggingface.co/Qwen/Qwen3-0.6B/tree/main?show_file_info=model.safetensors. I also remember seeing it while debugging the checkpoint. But sure, in that case could you reword the comment?

Contributor Author:
Reworded the comment a little.

I'm not convinced by https://huggingface.co/Qwen/Qwen3-0.6B/tree/main?show_file_info=model.safetensors because it's just metadata and doesn't prove anything about what is actually stored in the file.

It does not look like lm_head is present in the safetensors when I unpack them locally. Perhaps you looked at the checkpoint after running your script? (Which would have copied the embedding tensors into lm_head.)

Contributor:
Ok, if you double-checked then that's fine!
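
For reference, one way to settle this is to list the tensor names stored in the raw safetensors shard before the model is ever instantiated through transformers (which would run tie_weights() and materialize lm_head). A minimal sketch, assuming the safetensors package is installed and the checkpoint has been downloaded locally; the path is illustrative:

# Sketch: list the tensor names stored on disk in a safetensors shard,
# without building the model, so tie_weights() never runs.
from safetensors import safe_open

ckpt_path = "Qwen3-0.6B/model.safetensors"  # illustrative local path
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
    names = list(f.keys())

print("lm_head.weight stored on disk:", "lm_head.weight" in names)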

converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]

return converted_state_dict


def load_checkpoint(input_dir: str) -> Dict:
def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
index_path = os.path.join(input_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
# Sharded checkpoint.
@@ -86,6 +86,15 @@ def load_checkpoint(input_dir: str) -> Dict:
return state_dict


def load_checkpoint(input_dir: str) -> Dict:
pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
if os.path.exists(pytorch_path):
print("Loading checkpoint from PyTorch .bin file")
return torch.load(pytorch_path, map_location="cpu", weights_only=True)
print("Loading checkpoint from safetensors directory")
return load_checkpoint_from_safetensors(input_dir)


def convert_weights(input_dir: str, output_file: str) -> None:
print("Loading checkpoint...")
sd = load_checkpoint(input_dir)
@@ -103,7 +112,7 @@ def main():
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing checkpoint files",
help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")
