import os
import sys
import shutil
from pathlib import Path
from typing import Optional
from argparse import ArgumentParser, Namespace

import torch
from tqdm.auto import trange
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer

from permute_qkv import permute_qkv
from merge_llama import merge_llama

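# Per-size llama hyperparameters, keyed by parameter count in billions:
# number of layers, attention heads, FFN hidden size and model hidden size.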
llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
                 70: 28672}  # should be (2/3)*4*d, but it isn't exactly that
llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}


def llama_to_megatron(weights: dict, size: int, source: str = "meta",
                      version: int = 1) -> dict:
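    """Convert a merged llama/llama2 state dict into Megatron's language_model
    layout (embedding / encoder / lm_head).

    `source` is "meta" for the original consolidated.xx.pth weights, or "hf"
    for Huggingface weights, whose fused qkv matrix additionally needs the
    permutation applied by `permute_qkv`.
    """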
    def permute(qkv_w):
        if source == "hf":
            return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads)
        return qkv_w

    def rearrange_qkv(wq, wk, wv):
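        # Build the fused QKV matrix in the grouped layout Megatron expects:
        # for each of the n_kv_heads groups, the n_heads//n_kv_heads query
        # head slices come first, followed by that group's key and value slices.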
        wq = torch.split(wq, n_hidden_per_head, dim=0)
        wk = torch.split(wk, n_hidden_per_head, dim=0)
        wv = torch.split(wv, n_hidden_per_head, dim=0)
        assert len(wq) == n_heads
        assert len(wk) == n_kv_heads
        assert len(wv) == n_kv_heads
        n_qs_per_kv = n_heads//n_kv_heads
        w_qkv = []
        for i in range(n_kv_heads):
            w_qkv += [wq[i*n_qs_per_kv + j] for j in range(n_qs_per_kv)]
            w_qkv += [wk[i], wv[i]]
        return permute(torch.concat(w_qkv))

    # config
    n_layer = llama_s2layer[size]
    hidden = llama_s2hidden[size]
    n_heads = llama_s2heads[size]
    n_hidden_per_head = hidden//n_heads
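    # only llama2 models above 13B (i.e. the 70B) use grouped-query attention
    # with 8 KV heads; all other sizes use one KV head per query head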
    n_kv_heads = n_heads if version == 1 or size <= 13 else 8

    # weights independent of layers
    embedding = {"word_embeddings": {"weight": weights["tok_embeddings.weight"]}}
    transformer = {"final_layernorm.weight": weights["norm.weight"]}
    lm_head = weights["output.weight"]
    # get all the other weights
    for layer in trange(n_layer, desc="Converting weights"):
        prefix = f"layers.{layer}"
        # identical weights
        transformer[f"{prefix}.attention.dense.weight"] = \
            weights[f"{prefix}.attention.wo.weight"]
        transformer[f"{prefix}.post_attention_layernorm.weight"] = \
            weights[f"{prefix}.ffn_norm.weight"]
        transformer[f"{prefix}.input_layernorm.weight"] = \
            weights[f"{prefix}.attention_norm.weight"]
        transformer[f"{prefix}.mlp.dense_4h_to_h.weight"] = \
            weights[f"{prefix}.feed_forward.w2.weight"]
        # concatenate up, gate mlp weights
        transformer[f"{prefix}.mlp.dense_h_to_4h.weight"] = torch.concat([
            weights[f"{prefix}.feed_forward.w3.weight"],
            weights[f"{prefix}.feed_forward.w1.weight"]
        ])
        # finally, qkv requires serious manipulation to get right
        transformer[f"{prefix}.attention.query_key_value.weight"] = rearrange_qkv(
            weights[f"{prefix}.attention.wq.weight"],
            weights[f"{prefix}.attention.wk.weight"],
            weights[f"{prefix}.attention.wv.weight"]
        )

        # release references to original weights (free mem)
        del weights[f"{prefix}.feed_forward.w3.weight"]
        del weights[f"{prefix}.feed_forward.w1.weight"]
        del weights[f"{prefix}.attention.wq.weight"]
        del weights[f"{prefix}.attention.wk.weight"]
        del weights[f"{prefix}.attention.wv.weight"]

    return {"embedding": embedding, "encoder": transformer,
            "lm_head": lm_head}

def main(model_name: str = "llama2", size: int = 7, out: Optional[Path] = None,
         cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None,
         padded_vocab_size: Optional[int] = 32000):
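    """Load and merge the llama weights, convert them into a Megatron
    checkpoint, and save it (plus tokenizer.model when the weights come
    from Huggingface) under `out`.
    """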

    # load and merge the weights from the cache or the specified directory
    print("Getting llama...")
    version = 2 if "2" in model_name else 1
    hf_weights, llama_source = merge_llama(size, version, cache_dir, padded_vocab_size)

    # convert state dict to be megatron-compatible
    megatron_weights = llama_to_megatron(hf_weights, size, llama_source,
                                         version=1 if model_name == "llama" else 2)

    # set args
    # llama1, llama2
    args = {"num_layers": llama_s2layer[size],
            "hidden_size": llama_s2hidden[size],
            "num_attention_heads": llama_s2heads[size],
            "ffn_hidden_size": llama_s2dense[size],
            "num_key_value_heads": llama_s2heads[size],
            "parallel_attn": False,
            "make_vocab_size_divisible_by": 1,
            "glu_activation": "swiglu",
            # llama args
            "padded_vocab_size": padded_vocab_size,
            "use_rms_norm": True,
            "tie_embed_logits": False,
            "tokenizer_type": "GPTSentencePieceTokenizer",
            "no-query-key-layer-scaling": True,
            "attention-dropout": 0,
            "hidden-dropout": 0,
            "use-rotary-position-embeddings": True,
            "untie-embeddings-and-output-weights": True,
            "swiglu": True,
            "normalization": "rmsnorm",
            "disable-bias-linear": True,
            "add_position_embedding": False,
            "add_bias_linear": False,
            }
    if model_name == "llama":
        args.update({"max_position_embeddings": 2048, "seq_length": 2048,
                     "layernorm_epsilon": 1e-6})
    else:  # llama2 uses a 4096-token context
        args.update({"max_position_embeddings": 4096, "seq_length": 4096,
                     "layernorm_epsilon": 1e-5})
    if size >= 34:
        args.update({"num_attention_heads_kv": 8})

    args.update({
        "tensor_model_parallel_size": 1,
        "pipeline_model_parallel_size": 1,
        "iteration": "release",
        "bias_gelu_fusion": False,
        "bias_dropout_fusion": False,
    })

    # save converted weights in specified out
    (out/"release"/"mp_rank_00").mkdir(parents=True)
    with open(out/"latest_checkpointed_iteration.txt", "w+") as f:
        f.write("release")
    final_dict = {"iteration": "release", "model": {"language_model": megatron_weights},
                  "checkpoint_version": 3.0, "args": Namespace(**args)}
    torch.save(final_dict, out/"release"/"mp_rank_00"/"model_optim_rng.pt")
    print("Saved weights in", out)

    if model_name == "llama2" and llama_source == "hf":
        tokenizer = LlamaTokenizer.from_pretrained(
            cache_dir, cache_dir=cache_dir, local_files_only=True,
        )
        token_path = out/"tokenizer.model"
        vocab_file = tokenizer.vocab_file
        shutil.copy(vocab_file, token_path)
        print("Saved tokenizer.model in", token_path)
    print("Done")

if __name__ == "__main__":
    parser = ArgumentParser(description="Convert falcon or llama weights to "
                                        "megatron-compatible weights")
    parser.add_argument("model", choices={"falcon", "llama", "llama2"})
    parser.add_argument("--size", default=7, choices={7, 13, 30, 34, 40, 65, 70},
                        type=int, help="The size of the model")
    parser.add_argument("--out", type=Path,
                        help="Directory to store the megatron weights (as checkpoint)")
    parser.add_argument("--cache-dir", type=Path,
                        help=("Directory to store the huggingface weights, or "
                              "in case of the llama model, where to look for "
                              "the consolidated.xx.pth"))
    parser.add_argument("--megatron-path", type=Path,
                        help="Path where to find megatron code")
    parser.add_argument("--tokenizer-size", type=int, default=None,
                        help="Padded vocab size of the converted model (passed "
                             "to the conversion as padded_vocab_size)")
    args = parser.parse_args()

    # small arg verification
    if args.model == "llama":
        assert args.size in {7, 13, 30, 65}
    else:
        assert args.size in {7, 13, 70}

    main(args.model, args.size, args.out, args.cache_dir, args.megatron_path,
         args.tokenizer_size)
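
# Example invocation (script name and paths below are illustrative only):
#   python weights_conversion.py llama2 --size 7 \
#       --out ./megatron_llama2_7b --cache-dir ./llama-2-7b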