diff --git a/scripts/convert_checkpoint_hf.py b/scripts/convert_checkpoint_hf.py
new file mode 100644
index 00000000..15107a3c
--- /dev/null
+++ b/scripts/convert_checkpoint_hf.py
@@ -0,0 +1,242 @@
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import gc
+import json
+import os
+import shutil
+import warnings
+
+import torch
+
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+
+
+try:
+    from transformers import LlamaTokenizerFast
+except ImportError as e:
+    warnings.warn(e)
+    warnings.warn(
+        "The converted tokenizer will be the `slow` tokenizer. To use the fast tokenizer, update your `tokenizers` library and re-run the tokenizer conversion."
+    )
+    LlamaTokenizerFast = None
+
+"""
+Sample usage:
+
+```
+python scripts/convert_checkpoint_hf.py \
+    --input_dir /path/to/training_checkpoint --model_size 7B --output_dir /output/path
+```
+
+Thereafter, models can be loaded via:
+
+```py
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+model = LlamaForCausalLM.from_pretrained("/output/path")
+tokenizer = LlamaTokenizer.from_pretrained("/output/path")
+```
+
+Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
+come in several checkpoints, each shard contains a part of every weight of the model, so all of them have to be loaded in RAM).
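+
+The input checkpoint is assumed to be a single consolidated state dict whose keys follow the training
+code's naming, e.g. `transformer.h.{layer}.attn.c_attn.weight` (with the query/key/value projections fused
+into one matrix), `transformer.wte.weight`, `transformer.ln_f.scale`, and `lm_head.weight`.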
+""" + +INTERMEDIATE_SIZE_MAP = { + "7B": 11008, + "13B": 13824, + "30B": 17920, + "65B": 22016, + "70B": 28672, +} +NUM_SHARDS = { + "7B": 1, + "7Bf": 1, + "13B": 2, + "13Bf": 2, + "30B": 4, + "65B": 8, + "70B": 8, + "70Bf": 8, +} + +PARAMS = { + "7B": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": -1}, + "13B": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": -1}, + "30B": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": -1}, + "65B": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}, +} + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model(model_path, input_path, model_size, safe_serialization=True): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + params = PARAMS[model_size] + num_shards = NUM_SHARDS[model_size] + n_layers = params["n_layers"] + n_heads = params["n_heads"] + n_heads_per_shard = n_heads // num_shards + dim = params["dim"] + dims_per_head = dim // n_heads + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + + if "n_kv_heads" in params: + num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + num_local_key_value_heads = n_heads_per_shard // num_key_value_heads + key_value_dim = dim // num_key_value_heads + else: # compatibility with other checkpoints + num_key_value_heads = n_heads + num_local_key_value_heads = n_heads_per_shard + key_value_dim = dim + + # permute for sliced rotary + def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + print(f"Fetching all parameters from the checkpoint at {input_path}.") + + # Not sharded + # (The sharded implementation would also work, but this is simpler.) + # loaded = torch.load(os.path.join(input_path, "consolidated.00.pth"), map_location="cpu") + loaded = torch.load(input_path, map_location="cpu") + + + param_count = 0 + index_dict = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + + # Unsharded + # FYI, lots of changes made here... 
+        state_dict = {
+            f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
+                loaded[f"transformer.h.{layer_i}.attn.c_attn.weight"][0:dim]
+            ),
+            f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
+                loaded[f"transformer.h.{layer_i}.attn.c_attn.weight"][dim : 2 * dim]
+            ),
+            f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attn.c_attn.weight"][2 * dim : 3 * dim],
+            f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attn.c_proj.weight"],
+            f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.mlp.c_fc1.weight"],
+            f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.mlp.c_proj.weight"],
+            f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.mlp.c_fc2.weight"],
+            f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.rms_1.scale"],
+            f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.rms_2.scale"],
+        }
+
+        # inv_freq is not a trained parameter; the HF model recomputes it from the config, so it is stored
+        # here only for compatibility with older `transformers` versions.
+        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
+        for k, v in state_dict.items():
+            index_dict["weight_map"][k] = filename
+            param_count += v.numel()
+        torch.save(state_dict, os.path.join(tmp_model_path, filename))
+
+    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
+
+    # Unsharded
+    # The embeddings, final norm, and LM head go into the last shard.
+    state_dict = {
+        "model.embed_tokens.weight": loaded["transformer.wte.weight"],
+        "model.norm.weight": loaded["transformer.ln_f.scale"],
+        "lm_head.weight": loaded["lm_head.weight"],
+    }
+
+    for k, v in state_dict.items():
+        index_dict["weight_map"][k] = filename
+        param_count += v.numel()
+    torch.save(state_dict, os.path.join(tmp_model_path, filename))
+
+    # Write configs
+    index_dict["metadata"] = {"total_size": param_count * 2}  # 2 bytes per parameter (float16)
+    write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
+    ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
+    multiple_of = params["multiple_of"] if "multiple_of" in params else 256
+    config = LlamaConfig(
+        hidden_size=dim,
+        intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
+        num_attention_heads=params["n_heads"],
+        num_hidden_layers=params["n_layers"],
+        rms_norm_eps=params["norm_eps"],
+        num_key_value_heads=num_key_value_heads,
+    )
+    config.save_pretrained(tmp_model_path)
+
+    # Make space so we can load the model properly now.
+    del state_dict
+    del loaded
+    gc.collect()
+
+    print("Loading the checkpoint in a Llama model.")
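+    # Reloading from `tmp_model_path` reassembles the per-layer shards into a regular LlamaForCausalLM;
+    # `low_cpu_mem_usage=True` loads the shards incrementally, so peak RAM stays close to a single
+    # float16 copy of the weights.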
+    model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+    # Avoid saving this as part of the config.
+    del model.config._name_or_path
+
+    print("Saving in the Transformers format.")
+    model.save_pretrained(model_path, safe_serialization=safe_serialization)
+    shutil.rmtree(tmp_model_path)
+
+
+def write_tokenizer(tokenizer_path, input_tokenizer_path):
+    # Initialize the tokenizer based on the `spm` model
+    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
+    print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
+    tokenizer = tokenizer_class(input_tokenizer_path)
+    tokenizer.save_pretrained(tokenizer_path)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input_dir",
+        help="Path to the (unsharded) training checkpoint to convert",
+    )
+    parser.add_argument(
+        "--model_size",
+        # NOTE: only the sizes defined in PARAMS/NUM_SHARDS above can actually be converted.
+        choices=["7B", "7Bf", "13B", "13Bf", "30B", "65B", "70B", "70Bf", "tokenizer_only"],
+    )
+    parser.add_argument(
+        "--output_dir",
+        help="Location to write HF model and tokenizer",
+    )
+    parser.add_argument("--safe_serialization", action="store_true", help="Save the weights with `safetensors` instead of `.bin` shards.")
+    args = parser.parse_args()
+    if args.model_size != "tokenizer_only":
+        write_model(
+            model_path=args.output_dir,
+            input_path=args.input_dir,  # os.path.join(args.input_dir, args.model_size)
+            model_size=args.model_size,
+            safe_serialization=args.safe_serialization,
+        )
+    # spm_path = os.path.join(args.input_dir, "tokenizer.model")
+    # write_tokenizer(args.output_dir, spm_path)
+
+
+if __name__ == "__main__":
+    main()