
Commit 4c3b28a

Author: fengyu05
Commit message: add llama transfer script
1 parent 83cb7eb commit 4c3b28a

File tree

3 files changed: +388 -0 lines changed

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
import os
import re
from pathlib import Path
from typing import Optional
from collections import OrderedDict

import torch
from tqdm.auto import tqdm
from transformers import LlamaForCausalLM, AutoTokenizer


scale2emb = {
    '7B': 4096,
    '13B': 5120,
    '30B': 6656,
    '65B': 8192,
    '70B': 8192,
}


# Axis along which each parameter type is sharded across the consolidated.xx.pth
# files (None means the parameter is replicated and can be copied as-is).
key_to_dim = {
    "w1": 0,
    "w2": -1,
    "w3": 0,
    "wo": -1,
    "wq": 0,
    "wk": 0,
    "wv": 0,
    "output": 0,
    "tok_embeddings": -1,
    "ffn_norm": None,
    "attention_norm": None,
    "norm": None,
    "rope": None,
}


def init_merged_ckpt(pth_00, num_pth=8, emb_dim=8192):
    # Allocate the full-size tensors and copy the first shard into its slice.
    merged_ckpt = OrderedDict()
    for parameter_name, parameter in pth_00.items():
        short_name = parameter_name.split(".")[-2]
        if key_to_dim[short_name] is None:
            merged_ckpt[parameter_name] = parameter
            del parameter
        elif key_to_dim[short_name] == 0:
            size = parameter.shape[0]
            merged_param_shape = [parameter.shape[0] * num_pth, parameter.shape[1]]
            merged_ckpt[parameter_name] = torch.zeros(merged_param_shape)
            merged_ckpt[parameter_name][0:size, :] = parameter
            del parameter
        elif key_to_dim[short_name] == -1:
            size = parameter.shape[-1]
            merged_param_shape = [parameter.shape[0], parameter.shape[1] * num_pth]
            merged_ckpt[parameter_name] = torch.zeros(merged_param_shape)
            merged_ckpt[parameter_name][:, 0:size] = parameter
            del parameter
    return merged_ckpt


def merge_meta_llama(size: int, root_dir: Path):
    paths = sorted(path for path in root_dir.iterdir()
                   if re.match(r"^consolidated\.[0-9]+\.pth$", path.name))
    if len(paths) == 1:  # no sharded checkpoints, return everything
        return torch.load(paths[0], map_location=torch.device("cpu"))

    num_pth = len(paths)
    for i, ckpt_path in enumerate(tqdm(paths, desc="Merging llama")):
        llama_config = torch.load(ckpt_path, map_location=torch.device('cpu'))
        if i == 0:
            merged_ckpt = init_merged_ckpt(llama_config, num_pth=num_pth,
                                           emb_dim=scale2emb[f"{size}B"])
        else:
            # copy every later shard into its slice of the preallocated tensors
            for parameter_name, parameter in llama_config.items():
                short_name = parameter_name.split(".")[-2]
                if key_to_dim[short_name] == 0:
                    size = parameter.shape[0]
                    merged_param_shape = [parameter.shape[0] * num_pth, parameter.shape[1]]
                    merged_ckpt[parameter_name][size * i : size * (i + 1), :] = parameter
                    del parameter
                if key_to_dim[short_name] == -1:
                    size = parameter.shape[-1]
                    merged_param_shape = [parameter.shape[0], parameter.shape[1] * num_pth]
                    merged_ckpt[parameter_name][:, size * i : size * (i + 1)] = parameter
                    del parameter
        del llama_config
    return merged_ckpt


def merge_hf_llama(size: int, version: int, cache_dir: Optional[Path] = None, model_path=None, tokenizer_len=32000):
    assert version == 2, "Only llama v2 available using huggingface"
    print(cache_dir)
    model = LlamaForCausalLM.from_pretrained(cache_dir, cache_dir=cache_dir, local_files_only=True, use_safetensors=False)
    # resize the token embeddings according to the saved tokenizer when the vocabulary was extended
    # model.resize_token_embeddings(tokenizer_len)
    weights = model.state_dict()
    # rename the top-level huggingface keys to the meta-style names used downstream
    weights["tok_embeddings.weight"] = weights.pop("model.embed_tokens.weight")
    weights["norm.weight"] = weights.pop("model.norm.weight")
    weights["output.weight"] = weights.pop("lm_head.weight")
    for key in list(weights.keys()):
        if rmatch := re.match(r"^model\.(layers\.[0-9]+\.)(.+)(\.weight)$", key):
            new_key = {
                "self_attn.q_proj": "attention.wq",
                "self_attn.k_proj": "attention.wk",
                "self_attn.v_proj": "attention.wv",
                "self_attn.o_proj": "attention.wo",
                "mlp.gate_proj": "feed_forward.w1",
                "mlp.down_proj": "feed_forward.w2",
                "mlp.up_proj": "feed_forward.w3",
                "input_layernorm": "attention_norm",
                "post_attention_layernorm": "ffn_norm"
            }[rmatch.group(2)]
            weights[rmatch.group(1) + new_key + rmatch.group(3)] = weights.pop(key)
    return weights


def merge_llama(size: int, version: int, root_dir: Optional[Path] = None, tokenizer_len: Optional[int] = 32000):
    if root_dir is not None and (root_dir/"consolidated.00.pth").exists():
        return merge_meta_llama(size, root_dir), "meta"
    print(f"Weights at {root_dir} do not look like a meta checkpoint, assuming "
          "huggingface cache_dir instead")
    return merge_hf_llama(size, version, root_dir, tokenizer_len=tokenizer_len), "hf"
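For orientation, a minimal usage sketch, not part of the commit: it assumes the file above is saved as merge_llama.py (the module name imported by the conversion script in the third file) and that a local Llama-2-7B checkpoint directory exists at the placeholder path.

from pathlib import Path
from merge_llama import merge_llama  # assumed module name, taken from the import in the conversion script below

# Merge a (possibly sharded) 7B checkpoint; the second return value says whether
# the weights came from a meta-style directory or a huggingface cache_dir.
weights, source = merge_llama(size=7, version=2, root_dir=Path("/path/to/llama-2-7b"))
print(source, len(weights))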
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import re
import sys
import os
import shutil
from pathlib import Path
from argparse import ArgumentParser

import torch
from tqdm.auto import tqdm


def permute_qkv(qkv_w: torch.Tensor, dim: int, n_heads: int,
                n_heads_kv: int, revert: bool = False) -> torch.Tensor:

    def permute(x):
        # split each q/k head into its two rotary halves and interleave
        # (or, with revert=True, un-interleave) them
        if revert:
            return x.view(head_dim//2, 2, dim).transpose(0, 1).reshape(head_dim, dim)
        return x.view(2, head_dim//2, dim).transpose(0, 1).reshape(head_dim, dim)

    head_dim = dim//n_heads
    n_qs_per_kv = n_heads//n_heads_kv
    n_groups = qkv_w.size(0)//head_dim//(n_qs_per_kv + 2)
    groups = torch.chunk(qkv_w, n_groups, dim=0)
    new = []
    for group in groups:
        *qs, k, v = torch.split(group, head_dim, dim=0)
        assert len(qs) == n_qs_per_kv, f"{len(qs)}, {n_qs_per_kv}"
        new += list(map(permute, qs)) + [permute(k), v]
    return torch.cat(new, dim=0)


def update_checkpoint(input_dir: Path, output_dir: Path, overwrite_ok: bool = False):
    # make sure megatron is importable
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))

    # prepare output dir
    if output_dir.exists():
        if not overwrite_ok:
            raise FileExistsError(f"Output directory {output_dir} already exists")
        print(f"Removing {output_dir}")
        shutil.rmtree(output_dir)
    output_dir.mkdir(exist_ok=True)

    # determine release
    with open(input_dir/"latest_checkpointed_iteration.txt") as f:
        it = f.read()
    print("Updating weights of iteration", it)
    with open(output_dir/"latest_checkpointed_iteration.txt", "w+") as f:
        f.write(it)
    (output_dir/it).mkdir()

    # convert weights
    for fname in tqdm(list((input_dir/it).iterdir())):
        checkpoint = torch.load(fname/"model_optim_rng.pt")
        args = checkpoint["args"]
        args = (args.hidden_size, args.num_attention_heads,
                args.num_attention_heads_kv)
        if "transformer" in checkpoint["model"]["language_model"]:
            key = "transformer"
            attn_key = "attention"
        else:
            key = "encoder"
            attn_key = "self_attention"
        states = checkpoint["model"]["language_model"][key]
        for name, weight in states.items():
            if re.match(rf"^layers\.[0-9]+\.{attn_key}\.query_key_value\.weight$", name):
                states[name] = permute_qkv(weight, *args)
        (output_dir/it/fname.stem).mkdir()
        torch.save(checkpoint, output_dir/it/fname.stem/"model_optim_rng.pt")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-dir", type=Path)
    parser.add_argument("--output-dir", type=Path)
    parser.add_argument("--overwrite-ok", action="store_true")
    args = parser.parse_args()
    update_checkpoint(args.input_dir, args.output_dir, args.overwrite_ok)
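A quick self-check of the permutation helper, not part of the commit; it assumes the file above is saved as permute_qkv.py (the module name the conversion script below imports) and uses made-up toy dimensions.

import torch
from permute_qkv import permute_qkv  # assumed module name, matching the import in the conversion script

# Toy shapes: dim=128, n_heads=8, n_heads_kv=4 gives head_dim=16 and
# 4 groups of (2 query heads + 1 key + 1 value) = 4 * 4 * 16 = 256 rows.
qkv = torch.randn(256, 128)
permuted = permute_qkv(qkv, dim=128, n_heads=8, n_heads_kv=4)
restored = permute_qkv(permuted, dim=128, n_heads=8, n_heads_kv=4, revert=True)
assert torch.equal(qkv, restored)  # revert=True undoes the q/k permutation; v rows are never touched

The __main__ entry point above can also be run directly on a Megatron checkpoint directory via --input-dir and --output-dir (plus --overwrite-ok to replace an existing output).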
Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
import os
import sys
import shutil
from pathlib import Path
from typing import Optional
from argparse import ArgumentParser, Namespace

import torch
from tqdm.auto import trange
from transformers import AutoModelForCausalLM, LlamaTokenizer

from permute_qkv import permute_qkv
from merge_llama import merge_llama
from transformers import AutoTokenizer

llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
                 70: 28672}  # should be (2/3)*4*d, but it isn't exactly that
llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}


def llama_to_megatron(weights: dict, size: int, source: str = "meta",
                      version: int = 1) -> dict:
    def permute(qkv_w):
        if source == "hf":
            return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads)
        return qkv_w

    def rearrange_qkv(wq, wk, wv):
        wq = torch.split(wq, n_hidden_per_head, dim=0)
        wk = torch.split(wk, n_hidden_per_head, dim=0)
        wv = torch.split(wv, n_hidden_per_head, dim=0)
        assert len(wq) == n_heads
        assert len(wk) == n_kv_heads
        assert len(wv) == n_kv_heads
        n_qs_per_kv = n_heads//n_kv_heads
        w_qkv = []
        for i in range(n_kv_heads):
            w_qkv += [wq[i*n_qs_per_kv + j] for j in range(n_qs_per_kv)]
            w_qkv += [wk[i], wv[i]]
        return permute(torch.concat(w_qkv))

    # config
    n_layer = llama_s2layer[size]
    hidden = llama_s2hidden[size]
    n_heads = llama_s2heads[size]
    n_hidden_per_head = hidden//n_heads
    n_kv_heads = n_heads if version == 1 or size <= 13 else 8

    # weights independent of layers
    embedding = {"word_embeddings": {"weight": weights["tok_embeddings.weight"]}}
    transformer = {"final_layernorm.weight": weights["norm.weight"]}
    lm_head = weights["output.weight"]
    # get all the other weights
    for layer in trange(n_layer, desc="Converting weights"):
        prefix = f"layers.{layer}"
        # identical weights
        transformer[f"{prefix}.attention.dense.weight"] = \
            weights[f"{prefix}.attention.wo.weight"]
        transformer[f"{prefix}.post_attention_layernorm.weight"] = \
            weights[f"{prefix}.ffn_norm.weight"]
        transformer[f"{prefix}.input_layernorm.weight"] = \
            weights[f"{prefix}.attention_norm.weight"]
        transformer[f"{prefix}.mlp.dense_4h_to_h.weight"] = \
            weights[f"{prefix}.feed_forward.w2.weight"]
        # concatenate up, gate mlp weights
        transformer[f"{prefix}.mlp.dense_h_to_4h.weight"] = torch.concat([
            weights[f"{prefix}.feed_forward.w3.weight"],
            weights[f"{prefix}.feed_forward.w1.weight"]
        ])
        # finally, qkv requires serious manipulation to get right
        transformer[f"{prefix}.attention.query_key_value.weight"] = rearrange_qkv(
            weights[f"{prefix}.attention.wq.weight"],
            weights[f"{prefix}.attention.wk.weight"],
            weights[f"{prefix}.attention.wv.weight"]
        )

        # release references to original weights (free mem)
        del weights[f"{prefix}.feed_forward.w3.weight"]
        del weights[f"{prefix}.feed_forward.w1.weight"]
        del weights[f"{prefix}.attention.wq.weight"]
        del weights[f"{prefix}.attention.wk.weight"]
        del weights[f"{prefix}.attention.wv.weight"]

    return {"embedding": embedding, "encoder": transformer,
            "lm_head": lm_head}


def main(model_name: str = "llama2", size: int = 7, out: Optional[Path] = None,
         cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None, padded_vocab_size: Optional[int] = 32000):

    # get weights from the specified directory
    print("Getting llama...")
    version = 2 if "2" in model_name else 1
    hf_weights, llama_source = merge_llama(size, version, cache_dir, padded_vocab_size)

    # convert state dict to be megatron-compatible
    megatron_weights = llama_to_megatron(hf_weights, size, llama_source,
                                         version=1 if model_name == "llama" else 2)

    # set args
    # llama1, llama2
    args = {"num_layers": llama_s2layer[size],
            "hidden_size": llama_s2hidden[size],
            "num_attention_heads": llama_s2heads[size],
            "ffn_hidden_size": llama_s2dense[size],
            "num_key_value_heads": llama_s2heads[size],
            "parallel_attn": False,
            "make_vocab_size_divisible_by": 1,
            "glu_activation": "swiglu",
            # llama args
            "padded_vocab_size": padded_vocab_size,
            "use_rms_norm": True,
            "tie_embed_logits": False,
            "tokenizer_type": "GPTSentencePieceTokenizer",
            "no-query-key-layer-scaling": True,
            "attention-dropout": 0,
            "hidden-dropout": 0,
            "use-rotary-position-embeddings": True,
            "untie-embeddings-and-output-weights": True,
            "swiglu": True,
            "normalization": "rmsnorm",
            "disable-bias-linear": True,
            "add_position_embedding": False,
            "add_bias_linear": False,
            }
    if model_name == "llama":
        args.update({"max_position_embeddings": 2048, "seq_length": 2048,
                     "layernorm_epsilon": 1e-6})
    else:  # llama2
        args.update({"max_position_embeddings": 2048, "seq_length": 2048,
                     "layernorm_epsilon": 1e-5})
        if size >= 34:
            args.update({"num_attention_heads_kv": 8})

    args.update({
        "tensor_model_parallel_size": 1,
        "pipeline_model_parallel_size": 1,
        "iteration": "release",
        "bias_gelu_fusion": False,
        "bias_dropout_fusion": False,
    })

    # save converted weights in specified out
    (out/"release"/"mp_rank_00").mkdir(parents=True)
    with open(out/"latest_checkpointed_iteration.txt", "w+") as f:
        f.write("release")
    final_dict = {"iteration": "release", "model": {"language_model": megatron_weights},
                  "checkpoint_version": 3.0, "args": Namespace(**args)}
    torch.save(final_dict, out/"release"/"mp_rank_00"/"model_optim_rng.pt")
    print("Saved weights in", out)

    if model_name == "llama2" and llama_source == "hf":
        tokenizer = LlamaTokenizer.from_pretrained(
            cache_dir, cache_dir=cache_dir, local_files_only=True,
        )
        token_path = out/"tokenizer.model"
        vocab_file = tokenizer.vocab_file
        shutil.copy(vocab_file, token_path)
        print("Saved tokenizer.model in", token_path)
    print("Done")


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert llama weights (meta or huggingface) "
                                        "to megatron-compatible weights")
    parser.add_argument("model", choices={"falcon", "llama", "llama2"})
    parser.add_argument("--size", default=7, choices={7, 13, 30, 34, 40, 65, 70}, type=int,
                        help="The size of the model")
    parser.add_argument("--out", type=Path,
                        help="Directory to store the megatron weights (as checkpoint)")
    parser.add_argument("--cache-dir", type=Path,
                        help=("Directory to store the huggingface weights, or "
                              "in case of the llama model, where to look for "
                              "the consolidated.xx.pth"))
    parser.add_argument("--megatron-path", type=Path,
                        help="Path where to find megatron code")
    parser.add_argument("--tokenizer-size", type=int, default=None,
                        help="Size of the (possibly extended) tokenizer vocabulary, "
                             "used as the padded vocab size")
    args = parser.parse_args()

    # small arg verification
    if args.model == "llama":
        assert args.size in {7, 13, 30, 65}
    else:
        assert args.size in {7, 13, 70}

    main(args.model, args.size, args.out, args.cache_dir, args.megatron_path, args.tokenizer_size)
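End to end, a hypothetical direct call to main, not part of the commit; the paths are placeholders and the script's file name is not shown on this page, so invoke it however the repository names it.

from pathlib import Path
# Equivalent to running:
#   python <this script> llama2 --size 7 --cache-dir /path/to/llama-2-7b \
#       --out /path/to/megatron_ckpt --tokenizer-size 32000
main(model_name="llama2", size=7,
     out=Path("/path/to/megatron_ckpt"),
     cache_dir=Path("/path/to/llama-2-7b"),  # meta consolidated.xx.pth or a huggingface cache_dir
     megatron_path=None,
     padded_vocab_size=32000)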

0 commit comments
