diff --git a/accessory/tools/convert_weights_to_hf.py b/accessory/tools/convert_weights_to_hf.py index a32a2631..c90becc2 100644 --- a/accessory/tools/convert_weights_to_hf.py +++ b/accessory/tools/convert_weights_to_hf.py @@ -19,23 +19,41 @@ Example usage:: - # The folders to prepare: - # - # /path/to/llama-2-70b: Path to the original LLaMA-2-70B weights by Meta. - # /path/to/finetune/sg/dialog_sharegpt_70b: ShareGPT finetuned delta - # weights downloaded from our repo. - # /path/to/llama/tokenizer.model: Tokenizer model file released by Meta. - # /path/to/llama2_accessory_github_repo: Path to the cloned Github repo. - # - # Then, run in Bash: - - $ cd /path/to/llama2_accessory_github_repo - $ python -m tools.convert_weights_to_hf \ - --src_weights_path /path/to/llama-2-70b \ - /path/to/finetune/sg/dialog_sharegpt_70b \ - --src_config_path /path/to/llama-2-70b/params.json \ - --tokenizer_path /path/to/llama/tokenizer.model \ - --dst_weights_path /path/to/llama-2-70b-hf-sharegpt + If you are working with LLaMA models: + # The folders to prepare: + # + # /path/to/llama-2-70b: Path to the original LLaMA-2-70B weights by Meta. + # /path/to/finetune/sg/dialog_sharegpt_70b: ShareGPT finetuned delta + # weights downloaded from our repo. + # /path/to/llama/tokenizer.model: Tokenizer model file released by Meta. + # /path/to/llama2_accessory_github_repo: Path to the cloned Github repo. + # + # Then, run in Bash: + + $ cd /path/to/llama2_accessory_github_repo + $ python -m tools.convert_weights_to_hf \ + --src_weights_path /path/to/llama-2-70b \ + /path/to/finetune/sg/dialog_sharegpt_70b \ + --src_config_path /path/to/llama-2-70b/params.json \ + --tokenizer_path /path/to/llama/tokenizer.model \ + --dst_weights_path /path/to/llama-2-70b-hf-sharegpt + + If you are working with Mixtral MoE models: + # The folders to prepare: + # + # /path/to/MoE-Mixtral-7B-8Expert/converted: Path to the + # converted Mixtral MoE checkpoint (base implementation). + # /path/to/MoE-Mixtral-7B-8Expert/converted/tokenizer.model: Tokenizer model file shipped with the converted checkpoint. + # /path/to/llama2_accessory_github_repo: Path to the cloned Github repo. + # + # Then, run in Bash: + + $ cd /path/to/llama2_accessory_github_repo + $ python -m tools.convert_weights_to_hf --mixtral \ + --src_weights_path /path/to/MoE-Mixtral-7B-8Expert/converted \ + --src_config_path /path/to/MoE-Mixtral-7B-8Expert/converted/config.json \ + --tokenizer_path /path/to/MoE-Mixtral-7B-8Expert/converted/tokenizer.model \ + --dst_weights_path /path/to/MoE-Mixtral-7B-8Expert-hf # If the model to convert contains unknown parameters (e.g., converting a # multi-modal model to huggingface LLaMA which is language-only), add @@ -59,7 +77,6 @@ to the latest version of ``transformers`` (e.g., >= 4.32.0) should get rid of the warning. """ - import argparse import json import os @@ -92,31 +109,57 @@ def load_and_merge_tensor_parallel_weights( src_weights_path: List[str], torch_dtype: torch.dtype, ignore_unknown_keys: bool = False, + model_type: str = 'llama' ) -> Dict[str, torch.Tensor]: # Manually specify merge dim for each weight name pattern because: # 1. To avoid creating a model (and then infer the merge dim) to save # memory. # 2. Only weights actually supported by HuggingFace are listed (e.g., # biases are not supported now) so there won't be a lot of corner cases.
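As an illustrative aside (hypothetical tensor names and shapes, not code from this patch), the sketch below shows what the merge dims in the tables that follow mean: a dim of 0 or 1 concatenates the tensor-parallel shards of a weight along that axis, while -1 marks a weight that is replicated across ranks (or, for the Mixtral base implementation, stored whole on a single rank) and is therefore taken from one shard as-is::

    import torch

    # Two hypothetical tensor-parallel shards: wq is column-parallel (merge dim 0),
    # the norm weight is replicated (merge dim -1).  Shapes are made up.
    shard0 = {"llma.layers.0.attention.wq.weight": torch.randn(64, 128),
              "llma.layers.0.attention_norm.weight": torch.ones(128)}
    shard1 = {"llma.layers.0.attention.wq.weight": torch.randn(64, 128),
              "llma.layers.0.attention_norm.weight": torch.ones(128)}

    merged = {}
    for name, merge_dim in [("llma.layers.0.attention.wq.weight", 0),
                            ("llma.layers.0.attention_norm.weight", -1)]:
        if merge_dim == -1:
            # Replicated (or single-rank) weight: take one copy as-is.
            merged[name] = shard0[name]
        else:
            # Tensor-parallel weight: concatenate the shards along the merge dim.
            merged[name] = torch.cat([shard0[name], shard1[name]], dim=merge_dim)

    print(merged["llma.layers.0.attention.wq.weight"].shape)  # torch.Size([128, 128])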
- pattern_to_merge_dim = ( - ("^llma.tok_embeddings.weight$", 1), - ("^llma.layers.(\d+).attention.wq.weight$", 0), - ("^llma.layers.(\d+).attention.wk.weight$", 0), - ("^llma.layers.(\d+).attention.wv.weight$", 0), - ("^llma.layers.(\d+).attention.wo.weight$", 1), - ("^llma.layers.(\d+).attention_norm.weight", -1), - ("^llma.layers.(\d+).feed_forward.w1.weight$", 0), - ("^llma.layers.(\d+).feed_forward.w2.weight$", 1), - ("^llma.layers.(\d+).feed_forward.w3.weight$", 0), - ("^llma.layers.(\d+).ffn_norm.weight", -1), - ("^llma.output.weight$", 0), - ("^llma.norm.weight$", -1), - ("^llma.rope.freqs$", -1), - ) + if model_type == 'llama': + pattern_to_merge_dim = ( + ("^llma.tok_embeddings.weight$", 1), + ("^llma.layers.(\d+).attention.wq.weight$", 0), + ("^llma.layers.(\d+).attention.wk.weight$", 0), + ("^llma.layers.(\d+).attention.wv.weight$", 0), + ("^llma.layers.(\d+).attention.wo.weight$", 1), + ("^llma.layers.(\d+).attention_norm.weight", -1), + ("^llma.layers.(\d+).feed_forward.w1.weight$", 0), + ("^llma.layers.(\d+).feed_forward.w2.weight$", 1), + ("^llma.layers.(\d+).feed_forward.w3.weight$", 0), + ("^llma.layers.(\d+).ffn_norm.weight", -1), + ("^llma.output.weight$", 0), + ("^llma.norm.weight$", -1), + ("^llma.rope.freqs$", -1), + ) + elif model_type == 'mixtral': + pattern_to_merge_dim = ( + ("^llma.tok_embeddings.weight$", 1), + ("^llma.layers.(\d+).attention.wq.weight$", 0), + ("^llma.layers.(\d+).attention.wk.weight$", 0), + ("^llma.layers.(\d+).attention.wv.weight$", 0), + ("^llma.layers.(\d+).attention.wo.weight$", 1), + ("^llma.layers.(\d+).attention_norm.weight", -1), + ("^llma.layers.(\d+).feed_forward.gate.weight$", -1), + ("^llma.layers.(\d+).feed_forward.experts.(\d+).w1.weight$", -1), # for base impl TODO: support sparse impl + ("^llma.layers.(\d+).feed_forward.experts.(\d+).w2.weight$", -1), # for base impl TODO: support sparse impl + ("^llma.layers.(\d+).feed_forward.experts.(\d+).w3.weight$", -1), # for base impl TODO: support sparse impl + ("^llma.layers.(\d+).ffn_norm.weight", -1), + ("^llma.output.weight$", 0), + ("^llma.norm.weight$", -1), + ) + else: + raise NotImplementedError(f"Unsupported model type {model_type}") pattern_to_merge_dim = tuple( (re.compile(pattern), dim) for pattern, dim in pattern_to_merge_dim ) + + # these tensors are distributed # TODO: support sparse impl + distributed_pattern = ["^llma.layers.(\d+).feed_forward.experts.(\d+).w1.weight$", + "^llma.layers.(\d+).feed_forward.experts.(\d+).w2.weight$", + "^llma.layers.(\d+).feed_forward.experts.(\d+).w3.weight$"] + merged_ckpt = {} ignored_keys = [] for i, path in enumerate(src_weights_path): @@ -154,13 +197,15 @@ def load_and_merge_tensor_parallel_weights( ) merged_ckpt[key] = torch.zeros(merged_size, dtype=init_dtype) + distributed_loading = any([re.compile(pattern).match(key) for pattern in distributed_pattern]) if key not in sharded_tensor_loaders: sharded_tensor_loaders[key] = ShardedTensorLoader( - merged_ckpt[key], mp_size, merge_dim, + merged_ckpt[key], + 1 if distributed_loading else mp_size, + merge_dim, mode="add" if format.endswith("_diff") else "set" ) - sharded_tensor_loaders[key].load_shard(shard_id, value) - + sharded_tensor_loaders[key].load_shard(0 if distributed_loading else shard_id, value) for key, value in sharded_tensor_loaders.items(): assert value.is_complete(), ( "A key is not loaded completely after going through all " @@ -181,28 +226,56 @@ def load_and_merge_tensor_parallel_weights( def convert_merged_ckpt_to_hf( merged_state_dict: Dict[str, torch.Tensor], 
params: Dict[str, Any], + model_type: str = 'llama' ) -> List[Dict[str, torch.Tensor]]: merged_state_dict = merged_state_dict.copy() num_layers = 0 while (f"llma.layers.{num_layers}.attention_norm.weight" in merged_state_dict): num_layers += 1 + if model_type == 'mixtral': + num_experts = 0 + while (f"llma.layers.0.feed_forward.experts.{num_experts}.w1.weight" + in merged_state_dict): + num_experts += 1 + else: + num_experts = None hf_ckpts = [] - if "llma.rope.freqs" in merged_state_dict: - del merged_state_dict["llma.rope.freqs"] + if model_type == 'llama': + if "llma.rope.freqs" in merged_state_dict: + del merged_state_dict["llma.rope.freqs"] for i in range(num_layers): hf_ckpt_shard = {} - for src_key, dst_key in [ - ("attention.wq.weight", "self_attn.q_proj.weight"), - ("attention.wk.weight", "self_attn.k_proj.weight"), - ("attention.wv.weight", "self_attn.v_proj.weight"), - ("attention.wo.weight", "self_attn.o_proj.weight"), - ("feed_forward.w3.weight", "mlp.up_proj.weight"), - ("feed_forward.w2.weight", "mlp.down_proj.weight"), - ("feed_forward.w1.weight", "mlp.gate_proj.weight"), - ("attention_norm.weight", "input_layernorm.weight"), - ("ffn_norm.weight", "post_attention_layernorm.weight"), - ]: + if model_type == 'llama': + src_dst_name_mapping = [ + ("attention.wq.weight", "self_attn.q_proj.weight"), + ("attention.wk.weight", "self_attn.k_proj.weight"), + ("attention.wv.weight", "self_attn.v_proj.weight"), + ("attention.wo.weight", "self_attn.o_proj.weight"), + ("feed_forward.w3.weight", "mlp.up_proj.weight"), + ("feed_forward.w2.weight", "mlp.down_proj.weight"), + ("feed_forward.w1.weight", "mlp.gate_proj.weight"), + ("attention_norm.weight", "input_layernorm.weight"), + ("ffn_norm.weight", "post_attention_layernorm.weight"), + ] + elif model_type == 'mixtral': + src_dst_name_mapping = [ + ("attention.wq.weight", "self_attn.q_proj.weight"), + ("attention.wk.weight", "self_attn.k_proj.weight"), + ("attention.wv.weight", "self_attn.v_proj.weight"), + ("attention.wo.weight", "self_attn.o_proj.weight"), + ("attention_norm.weight", "input_layernorm.weight"), + ("ffn_norm.weight", "post_attention_layernorm.weight"), + ("feed_forward.gate.weight", "block_sparse_moe.gate.weight"), + ] + sum([[ + (f"feed_forward.experts.{exp_no}.w1.weight", f"block_sparse_moe.experts.{exp_no}.w1.weight"), + (f"feed_forward.experts.{exp_no}.w2.weight", f"block_sparse_moe.experts.{exp_no}.w2.weight"), + (f"feed_forward.experts.{exp_no}.w3.weight", f"block_sparse_moe.experts.{exp_no}.w3.weight"), + ] for exp_no in range(num_experts)], start=[]) + else: + raise NotImplementedError(f"Unsupported model type {model_type}") + + for src_key, dst_key in src_dst_name_mapping: dst_key = f"model.layers.{i}." + dst_key src_key = f"llma.layers.{i}." 
+ src_key value = merged_state_dict[src_key] @@ -282,7 +355,7 @@ def write_tokenizer(tokenizer_path: str, dest_dir: str) -> Any: def write_configs( - params: Dict[str, Any], dtype: torch.dtype, dest_dir: str, vocab_size: int + params: Dict[str, Any], dtype: torch.dtype, dest_dir: str, vocab_size: int, model_type: str ) -> None: def calculate_hidden_dim(): hidden_dim = params["dim"] * 4 @@ -295,56 +368,103 @@ def calculate_hidden_dim(): ) return hidden_dim - config = { - "architectures": [ - "LlamaForCausalLM" - ], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": params["dim"], - "initializer_range": 0.02, - "intermediate_size": calculate_hidden_dim(), - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": params["n_heads"], - "num_hidden_layers": params["n_layers"], - "num_key_value_heads": params.get("n_kv_heads", params["n_heads"]), - "pad_token_id": 0, - "pretraining_tp": 1, - "rms_norm_eps": params.get("norm_eps", 1e-5), - "rope_theta": params.get("rope_theta", 10000), - "rope_scaling": None if "rope_scaling" not in params else { - "type": "linear", - "factor": params["rope_scaling"], - }, - "tie_word_embeddings": False, - "torch_dtype": { - torch.float16: "float16", - torch.bfloat16: "bfloat16", - torch.float32: "float32", - }[dtype], - "transformers_version": transformers.__version__, - "use_cache": True, - "vocab_size": vocab_size - } + if model_type == 'llama': + config = { + "architectures": [ + "LlamaForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": params["dim"], + "initializer_range": 0.02, + "intermediate_size": calculate_hidden_dim(), + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": params["n_heads"], + "num_hidden_layers": params["n_layers"], + "num_key_value_heads": params.get("n_kv_heads", params["n_heads"]), + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": params.get("norm_eps", 1e-5), + "rope_theta": params.get("rope_theta", 10000), + "rope_scaling": None if "rope_scaling" not in params else { + "type": "linear", + "factor": params["rope_scaling"], + }, + "tie_word_embeddings": False, + "torch_dtype": { + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", + }[dtype], + "transformers_version": transformers.__version__, + "use_cache": True, + "vocab_size": vocab_size + } + elif model_type == 'mixtral': + config = { + "architectures": [ + "MixtralForCausalLM" + ], + "bos_token_id": 1, + "eos_token_id": 2, + "attention_dropout": 0.0, + "hidden_act": "silu", + "hidden_size": params["dim"], + "initializer_range": 0.02, + "intermediate_size": params["hidden_dim"], + "max_position_embeddings": 32768, + "model_type": "mixtral", + "num_attention_heads": params["n_heads"], + "num_experts_per_tok": 2, + "num_hidden_layers": params["n_layers"], + "num_key_value_heads": params.get("n_kv_heads", params["n_heads"]), + "output_router_logits": False, + "rms_norm_eps": params.get("norm_eps", 1e-5), + "rope_theta": params.get("rope_theta", 10000), + "rope_scaling": None if "rope_scaling" not in params or params["rope_scaling"] is None else { + "type": "linear", + "factor": params["rope_scaling"], + }, + "router_aux_loss_coef": 0.02, + "sliding_window": None, + "tie_word_embeddings": False, + "torch_dtype": { + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", + }[dtype], + "transformers_version": transformers.__version__, + "use_cache": True, + "vocab_size": 
vocab_size + } with open(os.path.join(dest_dir, "config.json"), "w") as f: json.dump(config, f, indent=2) - generation_config = { - "_from_model_config": True, - "bos_token_id": 1, - "eos_token_id": 2, - "pad_token_id": 0, - "transformers_version": transformers.__version__, - } + if model_type == 'llama': + generation_config = { + "_from_model_config": True, + "bos_token_id": 1, + "eos_token_id": 2, + "pad_token_id": 0, + "transformers_version": transformers.__version__, + } + elif model_type == 'mixtral': + generation_config = { + "_from_model_config": True, + "bos_token_id": 1, + "eos_token_id": 2, + "transformers_version": transformers.__version__, + } with open(os.path.join(dest_dir, "generation_config.json"), "w") as f: json.dump(generation_config, f, indent=2) def write_hf_ckpt( hf_state_dict: List[Dict[str, torch.Tensor]], dest_dir: str, - tokenizer_path: str, params: Dict[str, Any], torch_dtype: torch.dtype + tokenizer_path: str, params: Dict[str, Any], torch_dtype: torch.dtype, + model_type: str = 'llama' ) -> None: os.makedirs(dest_dir, exist_ok=True) print("Writing model weights ...") @@ -352,7 +472,7 @@ def write_hf_ckpt( print("Writing tokenizer ...") tokenizer = write_tokenizer(tokenizer_path, dest_dir) print("Writing configs ...") - write_configs(params, torch_dtype, dest_dir, tokenizer.vocab_size) + write_configs(params, torch_dtype, dest_dir, tokenizer.vocab_size, model_type) def main() -> None: @@ -388,6 +508,10 @@ def main() -> None: help="Ignore unknown keys in the source checkpoint (the scripts will " "only give warnings); otherwise the conversion will fail." ) + parser.add_argument( + "--mixtral", action="store_true", + help="Whether the model is of Mixtral MoE architecture." + ) args = parser.parse_args() params = {} @@ -403,13 +527,13 @@ def main() -> None: print("Loading and merging source checkpoints ...") src_ckpt_merged = load_and_merge_tensor_parallel_weights( - args.src_weights_path, torch_dtype, args.ignore_unknown_keys + args.src_weights_path, torch_dtype, args.ignore_unknown_keys, 'mixtral' if args.mixtral else 'llama' ) print("Converting to HuggingFace format ...") - hf_ckpt = convert_merged_ckpt_to_hf(src_ckpt_merged, params) + hf_ckpt = convert_merged_ckpt_to_hf(src_ckpt_merged, params, 'mixtral' if args.mixtral else 'llama') print("Writing HuggingFace checkpoints to disk ...") write_hf_ckpt(hf_ckpt, args.dst_weights_path, args.tokenizer_path, params, - torch_dtype) + torch_dtype, 'mixtral' if args.mixtral else 'llama') print("Done!") diff --git a/accessory/tools/mixtral_moe_split_from_hf.py b/accessory/tools/mixtral_moe_split_from_hf.py new file mode 100644 index 00000000..b5307e2c --- /dev/null +++ b/accessory/tools/mixtral_moe_split_from_hf.py @@ -0,0 +1,266 @@ +""" + Rewrite from + - https://huggingface.co/Alpha-VLLM/MoE-Mixtral-7B-8Expert/blob/main/converted/split.py + - https://huggingface.co/Alpha-VLLM/MoE-Mixtral-7B-8Expert/blob/main/converted_sparse/split_sparse.py, + but we split from the huggingface version checkpoint +""" + +# meta info for mixtral moe split arch +config_json_data = { + "dim": 4096, + "hidden_dim": 14336, + "head_dim": 128, + "n_layers": 32, + "n_heads": 32, + "n_kv_heads": 8, + "vocab_size": 32000, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "max_batch_size": 32, + "max_seq_len": 4096, + "moe": { + "num_experts_per_tok": 2, + "num_experts": 8 + }, + "rope_scaling": None +} + +meta_json_data = { + "llama_type": "mistral" +} +# end of meta info + + +# mapping huggingface ckpt tensor names to magnet ckpt tensor names 
+hf_name_to_magnet_name = [ + ["lm_head.weight", "llma.output.weight"], + ["llma.embed_tokens.weight", "llma.tok_embeddings.weight"], + ["llma.norm.weight", "llma.norm.weight"] +] + sum([[ + [f"llma.layers.{l}.block_sparse_moe.gate.weight", f"llma.layers.{l}.feed_forward.gate.weight"], + [f"llma.layers.{l}.input_layernorm.weight", f"llma.layers.{l}.attention_norm.weight"], + [f"llma.layers.{l}.post_attention_layernorm.weight", f"llma.layers.{l}.ffn_norm.weight"], + [f"llma.layers.{l}.self_attn.k_proj.weight", f"llma.layers.{l}.attention.wk.weight"], + [f"llma.layers.{l}.self_attn.q_proj.weight", f"llma.layers.{l}.attention.wq.weight"], + [f"llma.layers.{l}.self_attn.v_proj.weight", f"llma.layers.{l}.attention.wv.weight"], + [f"llma.layers.{l}.self_attn.o_proj.weight", f"llma.layers.{l}.attention.wo.weight"]] + + [[f"llma.layers.{l}.block_sparse_moe.experts.{e}.w1.weight", f"llma.layers.{l}.feed_forward.experts.{e}.w1.weight"] for e in range(8)] + + [[f"llma.layers.{l}.block_sparse_moe.experts.{e}.w2.weight", f"llma.layers.{l}.feed_forward.experts.{e}.w2.weight"] for e in range(8)] + + [[f"llma.layers.{l}.block_sparse_moe.experts.{e}.w3.weight", f"llma.layers.{l}.feed_forward.experts.{e}.w3.weight"] for e in range(8)] for l in range(32)], []) +hf_name_to_magnet_name = {item[0]: item[1] for item in hf_name_to_magnet_name} + + +weight_parallel_dim = {"llma.tok_embeddings.weight": 1, "llma.layers.0.attention.wq.weight": 0, + "llma.layers.0.attention.wq.bias": 0, "llma.layers.0.attention.wk.weight": 0, + "llma.layers.0.attention.wk.bias": 0, "llma.layers.0.attention.wv.weight": 0, + "llma.layers.0.attention.wv.bias": 0, "llma.layers.0.attention.wo.weight": 1, + "llma.layers.1.attention.wq.weight": 0, "llma.layers.1.attention.wq.bias": 0, + "llma.layers.1.attention.wk.weight": 0, "llma.layers.1.attention.wk.bias": 0, + "llma.layers.1.attention.wv.weight": 0, "llma.layers.1.attention.wv.bias": 0, + "llma.layers.1.attention.wo.weight": 1, "llma.layers.2.attention.wq.weight": 0, + "llma.layers.2.attention.wq.bias": 0, "llma.layers.2.attention.wk.weight": 0, + "llma.layers.2.attention.wk.bias": 0, "llma.layers.2.attention.wv.weight": 0, + "llma.layers.2.attention.wv.bias": 0, "llma.layers.2.attention.wo.weight": 1, + "llma.layers.3.attention.wq.weight": 0, "llma.layers.3.attention.wq.bias": 0, + "llma.layers.3.attention.wk.weight": 0, "llma.layers.3.attention.wk.bias": 0, + "llma.layers.3.attention.wv.weight": 0, "llma.layers.3.attention.wv.bias": 0, + "llma.layers.3.attention.wo.weight": 1, "llma.layers.4.attention.wq.weight": 0, + "llma.layers.4.attention.wq.bias": 0, "llma.layers.4.attention.wk.weight": 0, + "llma.layers.4.attention.wk.bias": 0, "llma.layers.4.attention.wv.weight": 0, + "llma.layers.4.attention.wv.bias": 0, "llma.layers.4.attention.wo.weight": 1, + "llma.layers.5.attention.wq.weight": 0, "llma.layers.5.attention.wq.bias": 0, + "llma.layers.5.attention.wk.weight": 0, "llma.layers.5.attention.wk.bias": 0, + "llma.layers.5.attention.wv.weight": 0, "llma.layers.5.attention.wv.bias": 0, + "llma.layers.5.attention.wo.weight": 1, "llma.layers.6.attention.wq.weight": 0, + "llma.layers.6.attention.wq.bias": 0, "llma.layers.6.attention.wk.weight": 0, + "llma.layers.6.attention.wk.bias": 0, "llma.layers.6.attention.wv.weight": 0, + "llma.layers.6.attention.wv.bias": 0, "llma.layers.6.attention.wo.weight": 1, + "llma.layers.7.attention.wq.weight": 0, "llma.layers.7.attention.wq.bias": 0, + "llma.layers.7.attention.wk.weight": 0, "llma.layers.7.attention.wk.bias": 0, + 
"llma.layers.7.attention.wv.weight": 0, "llma.layers.7.attention.wv.bias": 0, + "llma.layers.7.attention.wo.weight": 1, "llma.layers.8.attention.wq.weight": 0, + "llma.layers.8.attention.wq.bias": 0, "llma.layers.8.attention.wk.weight": 0, + "llma.layers.8.attention.wk.bias": 0, "llma.layers.8.attention.wv.weight": 0, + "llma.layers.8.attention.wv.bias": 0, "llma.layers.8.attention.wo.weight": 1, + "llma.layers.9.attention.wq.weight": 0, "llma.layers.9.attention.wq.bias": 0, + "llma.layers.9.attention.wk.weight": 0, "llma.layers.9.attention.wk.bias": 0, + "llma.layers.9.attention.wv.weight": 0, "llma.layers.9.attention.wv.bias": 0, + "llma.layers.9.attention.wo.weight": 1, "llma.layers.10.attention.wq.weight": 0, + "llma.layers.10.attention.wq.bias": 0, "llma.layers.10.attention.wk.weight": 0, + "llma.layers.10.attention.wk.bias": 0, "llma.layers.10.attention.wv.weight": 0, + "llma.layers.10.attention.wv.bias": 0, "llma.layers.10.attention.wo.weight": 1, + "llma.layers.11.attention.wq.weight": 0, "llma.layers.11.attention.wq.bias": 0, + "llma.layers.11.attention.wk.weight": 0, "llma.layers.11.attention.wk.bias": 0, + "llma.layers.11.attention.wv.weight": 0, "llma.layers.11.attention.wv.bias": 0, + "llma.layers.11.attention.wo.weight": 1, "llma.layers.12.attention.wq.weight": 0, + "llma.layers.12.attention.wq.bias": 0, "llma.layers.12.attention.wk.weight": 0, + "llma.layers.12.attention.wk.bias": 0, "llma.layers.12.attention.wv.weight": 0, + "llma.layers.12.attention.wv.bias": 0, "llma.layers.12.attention.wo.weight": 1, + "llma.layers.13.attention.wq.weight": 0, "llma.layers.13.attention.wq.bias": 0, + "llma.layers.13.attention.wk.weight": 0, "llma.layers.13.attention.wk.bias": 0, + "llma.layers.13.attention.wv.weight": 0, "llma.layers.13.attention.wv.bias": 0, + "llma.layers.13.attention.wo.weight": 1, "llma.layers.14.attention.wq.weight": 0, + "llma.layers.14.attention.wq.bias": 0, "llma.layers.14.attention.wk.weight": 0, + "llma.layers.14.attention.wk.bias": 0, "llma.layers.14.attention.wv.weight": 0, + "llma.layers.14.attention.wv.bias": 0, "llma.layers.14.attention.wo.weight": 1, + "llma.layers.15.attention.wq.weight": 0, "llma.layers.15.attention.wq.bias": 0, + "llma.layers.15.attention.wk.weight": 0, "llma.layers.15.attention.wk.bias": 0, + "llma.layers.15.attention.wv.weight": 0, "llma.layers.15.attention.wv.bias": 0, + "llma.layers.15.attention.wo.weight": 1, "llma.layers.16.attention.wq.weight": 0, + "llma.layers.16.attention.wq.bias": 0, "llma.layers.16.attention.wk.weight": 0, + "llma.layers.16.attention.wk.bias": 0, "llma.layers.16.attention.wv.weight": 0, + "llma.layers.16.attention.wv.bias": 0, "llma.layers.16.attention.wo.weight": 1, + "llma.layers.17.attention.wq.weight": 0, "llma.layers.17.attention.wq.bias": 0, + "llma.layers.17.attention.wk.weight": 0, "llma.layers.17.attention.wk.bias": 0, + "llma.layers.17.attention.wv.weight": 0, "llma.layers.17.attention.wv.bias": 0, + "llma.layers.17.attention.wo.weight": 1, "llma.layers.18.attention.wq.weight": 0, + "llma.layers.18.attention.wq.bias": 0, "llma.layers.18.attention.wk.weight": 0, + "llma.layers.18.attention.wk.bias": 0, "llma.layers.18.attention.wv.weight": 0, + "llma.layers.18.attention.wv.bias": 0, "llma.layers.18.attention.wo.weight": 1, + "llma.layers.19.attention.wq.weight": 0, "llma.layers.19.attention.wq.bias": 0, + "llma.layers.19.attention.wk.weight": 0, "llma.layers.19.attention.wk.bias": 0, + "llma.layers.19.attention.wv.weight": 0, "llma.layers.19.attention.wv.bias": 0, + 
"llma.layers.19.attention.wo.weight": 1, "llma.layers.20.attention.wq.weight": 0, + "llma.layers.20.attention.wq.bias": 0, "llma.layers.20.attention.wk.weight": 0, + "llma.layers.20.attention.wk.bias": 0, "llma.layers.20.attention.wv.weight": 0, + "llma.layers.20.attention.wv.bias": 0, "llma.layers.20.attention.wo.weight": 1, + "llma.layers.21.attention.wq.weight": 0, "llma.layers.21.attention.wq.bias": 0, + "llma.layers.21.attention.wk.weight": 0, "llma.layers.21.attention.wk.bias": 0, + "llma.layers.21.attention.wv.weight": 0, "llma.layers.21.attention.wv.bias": 0, + "llma.layers.21.attention.wo.weight": 1, "llma.layers.22.attention.wq.weight": 0, + "llma.layers.22.attention.wq.bias": 0, "llma.layers.22.attention.wk.weight": 0, + "llma.layers.22.attention.wk.bias": 0, "llma.layers.22.attention.wv.weight": 0, + "llma.layers.22.attention.wv.bias": 0, "llma.layers.22.attention.wo.weight": 1, + "llma.layers.23.attention.wq.weight": 0, "llma.layers.23.attention.wq.bias": 0, + "llma.layers.23.attention.wk.weight": 0, "llma.layers.23.attention.wk.bias": 0, + "llma.layers.23.attention.wv.weight": 0, "llma.layers.23.attention.wv.bias": 0, + "llma.layers.23.attention.wo.weight": 1, "llma.layers.24.attention.wq.weight": 0, + "llma.layers.24.attention.wq.bias": 0, "llma.layers.24.attention.wk.weight": 0, + "llma.layers.24.attention.wk.bias": 0, "llma.layers.24.attention.wv.weight": 0, + "llma.layers.24.attention.wv.bias": 0, "llma.layers.24.attention.wo.weight": 1, + "llma.layers.25.attention.wq.weight": 0, "llma.layers.25.attention.wq.bias": 0, + "llma.layers.25.attention.wk.weight": 0, "llma.layers.25.attention.wk.bias": 0, + "llma.layers.25.attention.wv.weight": 0, "llma.layers.25.attention.wv.bias": 0, + "llma.layers.25.attention.wo.weight": 1, "llma.layers.26.attention.wq.weight": 0, + "llma.layers.26.attention.wq.bias": 0, "llma.layers.26.attention.wk.weight": 0, + "llma.layers.26.attention.wk.bias": 0, "llma.layers.26.attention.wv.weight": 0, + "llma.layers.26.attention.wv.bias": 0, "llma.layers.26.attention.wo.weight": 1, + "llma.layers.27.attention.wq.weight": 0, "llma.layers.27.attention.wq.bias": 0, + "llma.layers.27.attention.wk.weight": 0, "llma.layers.27.attention.wk.bias": 0, + "llma.layers.27.attention.wv.weight": 0, "llma.layers.27.attention.wv.bias": 0, + "llma.layers.27.attention.wo.weight": 1, "llma.layers.28.attention.wq.weight": 0, + "llma.layers.28.attention.wq.bias": 0, "llma.layers.28.attention.wk.weight": 0, + "llma.layers.28.attention.wk.bias": 0, "llma.layers.28.attention.wv.weight": 0, + "llma.layers.28.attention.wv.bias": 0, "llma.layers.28.attention.wo.weight": 1, + "llma.layers.29.attention.wq.weight": 0, "llma.layers.29.attention.wq.bias": 0, + "llma.layers.29.attention.wk.weight": 0, "llma.layers.29.attention.wk.bias": 0, + "llma.layers.29.attention.wv.weight": 0, "llma.layers.29.attention.wv.bias": 0, + "llma.layers.29.attention.wo.weight": 1, "llma.layers.30.attention.wq.weight": 0, + "llma.layers.30.attention.wq.bias": 0, "llma.layers.30.attention.wk.weight": 0, + "llma.layers.30.attention.wk.bias": 0, "llma.layers.30.attention.wv.weight": 0, + "llma.layers.30.attention.wv.bias": 0, "llma.layers.30.attention.wo.weight": 1, + "llma.layers.31.attention.wq.weight": 0, "llma.layers.31.attention.wq.bias": 0, + "llma.layers.31.attention.wk.weight": 0, "llma.layers.31.attention.wk.bias": 0, + "llma.layers.31.attention.wv.weight": 0, "llma.layers.31.attention.wv.bias": 0, + "llma.layers.31.attention.wo.weight": 1, "llma.output.weight": 0, "llma.output.bias": 0} + +import 
argparse +import torch +from pathlib import Path +import shutil +import json + +parser = argparse.ArgumentParser() +parser.add_argument('in_folder', type=str, help='Model folder that stores the original ckpt') +parser.add_argument('out_folder', type=str, help='Model folder that stores the output ckpt') +parser.add_argument('--in_ckpt_source', type=str, default='hf', choices=['hf', 'magnet'], help='Input model folder source') +parser.add_argument('--convert_sparse', action='store_true', help='Convert to the sparse format') +if __name__ == '__main__': + args = parser.parse_args() + + Path(args.out_folder).mkdir(exist_ok=True) + + # copy the tokenizer and write the config / meta files for the split checkpoint + shutil.copy(Path(args.in_folder) / 'tokenizer.model', Path(args.out_folder) / 'tokenizer.model') + with open(Path(args.out_folder) / 'meta.json', 'w') as f: + json.dump(meta_json_data, f) + with open(Path(args.out_folder) / 'config.json', 'w') as f: + json.dump(config_json_data, f) + + if args.in_ckpt_source == 'magnet': + ori = torch.load(Path(args.in_folder) / "consolidated.00.pth", map_location="cpu") + ori = {"llma." + key: val for key, val in ori.items()} + else: + ori = {} + + import json + import os.path as osp + import safetensors + from safetensors import safe_open + + with open(osp.join(args.in_folder, 'model.safetensors.index.json'), 'r') as f: + the_map = json.load(f) + print('metadata:', the_map['metadata']) + all_partitions = set(the_map['weight_map'].values()) + for now_partition in all_partitions: + with safe_open(osp.join(args.in_folder, now_partition), framework="pt", device="cpu") as f: + for key in f.keys(): + new_key = hf_name_to_magnet_name[key.replace('model.', 'llma.')] + ori[new_key] = f.get_tensor(key) + + if "wq" in new_key or "wk" in new_key: + print('transposing', new_key) + # to be compatible with HuggingFace's pos embed implementation. + head_dim = 128 + in_dim = ori[new_key].size(1) + ori[new_key] = ori[new_key].view( + -1, 2, head_dim // 2, in_dim, + ).transpose(1, 2).flatten(0, 2).contiguous() + + def func(rank=0): + shard_split_to = 8 + split_ckpt = {} + for key, ori_param in ori.items(): + if key in weight_parallel_dim: + split_ckpt[key] = torch.chunk(ori_param, shard_split_to, weight_parallel_dim[key])[ + rank % shard_split_to].clone() + if args.in_ckpt_source == 'hf': + split_ckpt[key] = split_ckpt[key].half() + if rank == 0: + print(f"chunk {key}") + else: + if not args.convert_sparse: + if "experts." in key and int(key.split("experts.")[1].split(".")[0]) != rank: + continue + else: + split_ckpt[key] = ori_param + if args.in_ckpt_source == 'hf': + split_ckpt[key] = split_ckpt[key].half() + if rank == 0: + print(f"inherit {key}") + else: + if "experts.0." 
in key: + weight_all_experts = [ori[key.replace("experts.0.", f"experts.{i}.")] for i in range(8)] + if "w2" in key: + weight_all_experts = [torch.transpose(_, 0, 1) for _ in weight_all_experts] + weight_this_rank = [torch.chunk(_, 8, dim=0)[rank] for _ in weight_all_experts] + weight_this_rank = torch.cat(weight_this_rank, dim=0).clone() + key = key.replace("experts.0.", "").replace(".weight", "") + split_ckpt[key] = weight_this_rank + if args.in_ckpt_source == 'hf': + split_ckpt[key] = split_ckpt[key].half() + print("expert key") + elif "experts" in key: + continue + else: + split_ckpt[key] = ori_param + if args.in_ckpt_source == 'hf': + split_ckpt[key] = split_ckpt[key].half() + if rank == 0: + print(f"inherit {key}") + print('saving at rank', rank) + torch.save({"model": split_ckpt}, osp.join(args.out_folder, f"consolidated.{rank:02d}-of-08.model.pth")) + + for r in range(8): + func(r) + diff --git a/requirements.txt b/requirements.txt index c04ea8aa..17e30c13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ httpx[socks] einops regex h5py +safetensors
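As a companion note to mixtral_moe_split_from_hf.py above, the following sketch (made-up shapes, a hypothetical split_for_rank helper, and only two weight kinds; replicated weights such as norms and the MoE gate are omitted) illustrates the base, non-sparse layout the script produces: tensor-parallel weights are chunked across the 8 ranks along their parallel dim, while each expert's FFN weights are kept whole on the single rank whose index matches the expert id::

    import torch

    full = {
        "llma.layers.0.attention.wq.weight": torch.randn(32, 16),              # tensor-parallel, dim 0
        "llma.layers.0.feed_forward.experts.3.w1.weight": torch.randn(56, 16), # owned by expert 3
    }

    def split_for_rank(rank, num_ranks=8):
        shard = {}
        for key, value in full.items():
            if "experts." in key:
                expert_id = int(key.split("experts.")[1].split(".")[0])
                if expert_id == rank:
                    # Whole tensor, stored only on the rank owning this expert.
                    shard[key] = value
            else:
                # Tensor-parallel weight: each rank keeps its own chunk.
                shard[key] = torch.chunk(value, num_ranks, dim=0)[rank]
        return shard

    print(sorted(split_for_rank(3).keys()))  # rank 3 holds expert 3 plus its wq chunk
    print(sorted(split_for_rank(0).keys()))  # other ranks only hold their wq chunk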
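Finally, a quick way to sanity-check the output of convert_weights_to_hf.py is to load it back with transformers. This is only a sketch: it assumes a transformers release with Mixtral support (>= 4.36) and accelerate installed for device_map="auto", it reuses the placeholder output path from the docstring above, and the full model needs substantial memory, so adjust dtype and device placement to your hardware::

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt_dir = "/path/to/MoE-Mixtral-7B-8Expert-hf"  # the --dst_weights_path used above
    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
    model = AutoModelForCausalLM.from_pretrained(
        ckpt_dir, torch_dtype=torch.float16, device_map="auto")

    # Generate a few tokens to confirm the converted weights produce sensible text.
    inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(output[0], skip_special_tokens=True))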