Commit 43020ef

Comment / lint
1 parent 6648372 commit 43020ef

File tree

2 files changed (+17 lines, -20 lines)


examples/models/llama/rope.py

Lines changed: 1 addition & 0 deletions

@@ -233,6 +233,7 @@ def qwen_apply_rotary_emb(
     """
     Apply Qwen2-style RoPE to query and key tensors.
     """
+
     def rotate_half(x):
         """Rotates half the hidden dims of the input."""
         x1 = x[..., : x.shape[-1] // 2]
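
The hunk cuts off inside rotate_half. For reference, a sketch of the full helper, assuming it matches the standard Hugging Face-style implementation (the last two lines below fall outside this commit's hunk and are not part of the change):

    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        # Assumed tail: concatenate the negated second half before the first.
        return torch.cat((-x2, x1), dim=-1)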

examples/models/qwen2_5/convert_weights.py

Lines changed: 16 additions & 20 deletions
@@ -1,10 +1,11 @@
 from typing import Dict
 
-from torchtune.training import FullModelHFCheckpointer
-# from torchtune.models import convert_weights
-from torchtune.models.convert_weights import get_mapped_key
 import torch
 
+from torchtune.models.convert_weights import get_mapped_key
+
+from torchtune.training import FullModelHFCheckpointer
+
 # Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
 _QWEN_2_FROM_META = {
     "tok_embeddings.weight": "tok_embeddings.weight",
@@ -23,6 +24,7 @@
     "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
 }
 
+
 def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     """
     Convert a state dict from torchtune's format to Meta's format. This function
@@ -43,32 +45,26 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
         converted_state_dict[new_key] = value
 
     # 0.5b and 1.5b models share the same weights for tok_embeddings and output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733.
-    converted_state_dict["output.weight"] = converted_state_dict["tok_embeddings.weight"]
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
 
     return converted_state_dict
 
+
 # TODO: no need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
 checkpointer = FullModelHFCheckpointer(
-    checkpoint_dir='/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/',
-    checkpoint_files=['model.safetensors'],
-    output_dir='.' ,
-    model_type='QWEN2'
+    checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/",
+    checkpoint_files=["model.safetensors"],
+    output_dir=".",
+    model_type="QWEN2",
 )
 
 print("Loading checkpoint")
 sd = checkpointer.load_checkpoint()
 
-print("HF weights:")
-for weight in sd["model"].keys():
-    print(weight)
-print()
-
-# Convert from TorchTune to Meta (PyTorch native)
-sd = qwen_2_tune_to_meta(sd['model'])
-
-print("Meta weights:")
-for weight in sd.keys():
-    print(weight)
+# Convert from TorchTune to Meta (PyTorch native).
+sd = qwen_2_tune_to_meta(sd["model"])
 
 print("Saving checkpoint")
-torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth")
+torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth")
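
The body of qwen_2_tune_to_meta sits mostly outside the hunks above. For context, a minimal sketch of the key-mapping it performs, modeled on torchtune's other convert_weights helpers; the inverted-mapping loop is an assumption, since only get_mapped_key, _QWEN_2_FROM_META, and the tied-embedding assignment appear in this diff:

    from typing import Dict

    import torch
    from torchtune.models.convert_weights import get_mapped_key

    def qwen_2_tune_to_meta_sketch(
        state_dict: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        # _QWEN_2_FROM_META (defined in the file above) maps Meta -> TorchTune
        # names, so converting TorchTune -> Meta inverts it before each lookup.
        inverted_mapping = {v: k for k, v in _QWEN_2_FROM_META.items()}
        converted_state_dict = {}
        for key, value in state_dict.items():
            converted_state_dict[get_mapped_key(key, inverted_mapping)] = value
        # The 0.5B/1.5B checkpoints tie input and output embeddings
        # (QwenLM/Qwen2.5#733), so the output head is copied from tok_embeddings.
        converted_state_dict["output.weight"] = converted_state_dict[
            "tok_embeddings.weight"
        ]
        return converted_state_dict

The resulting Meta-format state dict is what the script then writes out with torch.save.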
