from typing import Dict

- from torchtune.training import FullModelHFCheckpointer
- # from torchtune.models import convert_weights
- from torchtune.models.convert_weights import get_mapped_key
import torch

+ from torchtune.models.convert_weights import get_mapped_key
+
+ from torchtune.training import FullModelHFCheckpointer
+
# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
_QWEN_2_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    # ...
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}
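
# How the "{}" placeholder resolves (a sketch, assuming the elided function
# body inverts _QWEN_2_FROM_META and calls get_mapped_key, as torchtune's
# own tune-to-meta converters do):
#
#     inverted = {v: k for k, v in _QWEN_2_FROM_META.items()}
#     get_mapped_key("layers.3.mlp.w3.weight", inverted)
#     # -> "layers.3.feed_forward.w3.weight"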

+
def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
@@ -43,32 +45,26 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
        converted_state_dict[new_key] = value

    # 0.5b and 1.5b models share the same weights for tok_embeddings and output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733.
-     converted_state_dict["output.weight"] = converted_state_dict["tok_embeddings.weight"]
+     converted_state_dict["output.weight"] = converted_state_dict[
+         "tok_embeddings.weight"
+     ]
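    # Assigning the same tensor object keeps the weights tied: torch.save
    # stores their shared storage only once, and torch.load restores the sharing.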

    return converted_state_dict

+
# TODO: no need to use the TorchTune checkpointer; we can aggregate the checkpoint files ourselves.
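# A minimal sketch of that TODO, assuming a single-file safetensors checkpoint
# (safetensors.torch.load_file returns a flat Dict[str, torch.Tensor]):
#
#     from safetensors.torch import load_file
#     hf_sd = load_file("model.safetensors")
#
# Note that besides aggregating shards, the checkpointer below also renames
# HF keys to TorchTune's convention, which a hand-rolled loader would need
# to replicate.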
checkpointer = FullModelHFCheckpointer(
-     checkpoint_dir='/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/',
-     checkpoint_files=['model.safetensors'],
-     output_dir='.',
-     model_type='QWEN2'
+     checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/",
+     checkpoint_files=["model.safetensors"],
+     output_dir=".",
+     model_type="QWEN2",
)

print("Loading checkpoint")
sd = checkpointer.load_checkpoint()
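# load_checkpoint returns the weights nested under the "model" key.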

- print("HF weights:")
- for weight in sd["model"].keys():
-     print(weight)
- print()
-
- # Convert from TorchTune to Meta (PyTorch native)
- sd = qwen_2_tune_to_meta(sd['model'])
-
- print("Meta weights:")
- for weight in sd.keys():
-     print(weight)
+ # Convert from TorchTune to Meta (PyTorch native).
+ sd = qwen_2_tune_to_meta(sd["model"])

print("Saving checkpoint")
- torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth")
+ torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth")
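
# To sanity-check the saved file (an illustrative snippet, not part of the
# original script):
#
#     sd = torch.load("/home/jackzhxng/models/qwen2_5-1_5b.pth", map_location="cpu")
#     print(sorted(sd.keys()))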