Commit 6648372 (1 parent: 9cc5238)

Fix output embedding

1 file changed: 3 additions (+), 2 deletions (-)

examples/models/qwen2_5/convert_weights.py

@@ -9,7 +9,6 @@
 _QWEN_2_FROM_META = {
     "tok_embeddings.weight": "tok_embeddings.weight",
     "norm.weight": "norm.scale",
-    "output.weight": "output.weight",
     "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
     "layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias",
     "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
@@ -22,7 +21,6 @@
     "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
     "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
     "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
-
 }
 
 def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
@@ -44,6 +42,9 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value
 
+    # 0.5b and 1.5b models share the same weights for tok_embeddings and output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733.
+    converted_state_dict["output.weight"] = converted_state_dict["tok_embeddings.weight"]
+
     return converted_state_dict
 
 # TODO: no need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
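Context for the fix: per the linked issue, the Qwen2.5 0.5b and 1.5b checkpoints tie the input and output embeddings, storing the shared matrix only under tok_embeddings.weight, so a converter targeting a format with a separate output.weight key has to materialize that key itself. Below is a minimal sketch of the weight-tying idiom in a toy model; the names TiedLM, vocab_size, and dim are illustrative assumptions, not part of the repository's API.

import torch
import torch.nn as nn

# Toy model illustrating tied input/output embeddings (hypothetical,
# for illustration only; not the repository's model class).
class TiedLM(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # Weight tying: both attribute names now refer to one tensor,
        # so a released checkpoint only needs to store it once.
        self.output.weight = self.tok_embeddings.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))

model = TiedLM(vocab_size=8, dim=4)
# A checkpoint from such a model carries the shared matrix once; a
# converter that expects a separate "output.weight" key copies it in,
# mirroring what the commit above does.
state = {"tok_embeddings.weight": model.tok_embeddings.weight.detach()}
state["output.weight"] = state["tok_embeddings.weight"]
assert torch.equal(state["output.weight"], state["tok_embeddings.weight"])

Note that the commit's assignment copies by reference: both keys in the converted state dict point at the same tensor in memory, which matches the tied-weight semantics of the original checkpoint.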
