diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 12de60442a43e..26fcfef20fa1b 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -346,6 +346,8 @@ def prepare_tensors(self):
                         data_qtype = gguf.GGMLQuantizationType.BF16
                     elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
                         data_qtype = gguf.GGMLQuantizationType.Q8_0
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q4_0:
+                        data_qtype = gguf.GGMLQuantizationType.Q4_0
                     elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0:
                         data_qtype = gguf.GGMLQuantizationType.TQ1_0
                     elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
@@ -6394,24 +6396,22 @@ def set_gguf_parameters(self):


 @ModelBase.register("HunYuanMoEV1ForCausalLM")
-class HunYuanMoEModel(LlamaModel):
+class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
-    undo_permute = False

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        # For handling tied embeddings
+        self._tok_embd = None

     def set_vocab(self):
-        self._set_vocab_gpt2()
-
-    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
-        tokens: list[str] = []
-        toktypes: list[int] = []
-
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
-        # merge logic is copied from QwenModel, maybe incorrect
+
+        # 1. Get the pre-tokenizer identifier hash
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        # 2. Reverse-engineer the merges list from mergeable_ranks
         merges = []
         vocab = {}
         mergeable_ranks = tokenizer.mergeable_ranks
@@ -6420,68 +6420,95 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:
             if len(token) == 1:
                 continue
             merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
-            if len(merged) == 2:
+            if len(merged) == 2: #todo this is an assert in Qwen, why?
                 merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

-        self.gguf_writer.add_token_merges(merges)
-        reverse_vocab = tokenizer.decoder
-        assert max(reverse_vocab.keys()) < tokenizer.vocab_size
-
-        tokpre = self.get_vocab_base_pre(tokenizer)
-        added_vocab = tokenizer.get_added_vocab()
-
-        added_tokens_decoder = tokenizer.added_tokens_decoder
-
-        for i in range(tokenizer.vocab_size):
+        # 3. Generate the tokens and toktypes lists
+        vocab_size = self.hparams["vocab_size"]
+        assert tokenizer.vocab_size == vocab_size
+        special_tokens = tokenizer.special_tokens
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+        tokens: list[str] = []
+        toktypes: list[int] = []
+        for i in range(vocab_size):
             if i not in reverse_vocab:
                 tokens.append(f"[PAD{i}]")
                 toktypes.append(gguf.TokenType.UNUSED)
             else:
-                token: str = reverse_vocab[i]
-                if token in added_vocab:
-                    # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
-                    # To avoid unexpected issues - we make sure to normalize non-normalized tokens
-                    if not added_tokens_decoder[i].normalized:
-                        previous_token = token
-                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
-                        if previous_token != token:
-                            logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
-
-                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
-                        toktypes.append(gguf.TokenType.CONTROL)
-                    else:
-                        # NOTE: this was added for Gemma.
-                        # Encoding and decoding the tokens above isn't sufficient for this case.
-                        token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
-                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                token = reverse_vocab[i]
+                tokens.append(token)
+                if i in special_tokens.values():
+                    toktypes.append(gguf.TokenType.CONTROL)
                 else:
                     toktypes.append(gguf.TokenType.NORMAL)
-                tokens.append(token)

-        return tokens, toktypes, tokpre
+        # 4. Write all vocab-related fields to the GGUF writer
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_token_merges(merges)
+
+        # 5. Add special tokens and chat templates
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.add_to_gguf(self.gguf_writer)
+        # FIX for BOS token: Overwrite incorrect id read from config.json
+        self.gguf_writer.add_bos_token_id(127959) # <|bos|>

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        hparams = self.hparams

-        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
-        self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])

-        moe_intermediate_size = self.hparams["moe_intermediate_size"]
+        moe_intermediate_size = hparams["moe_intermediate_size"]
         assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
         self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])

-        moe_topk = self.hparams["moe_topk"]
+        moe_topk = hparams["moe_topk"]
         assert all(topk == moe_topk[0] for topk in moe_topk)
         self.gguf_writer.add_expert_used_count(moe_topk[0])

+        moe_shared_expert = hparams["num_shared_expert"]
+        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
+        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
+
+        # Rope
+        rope_scaling = hparams.get("rope_scaling", {})
+        if rope_scaling.get("type") == "dynamic":
+            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+            alpha = rope_scaling.get("alpha", 1000)
+            base = hparams.get("rope_theta", 10000.0)
+            dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128
+            scaled_base = base * (alpha ** (dim / (dim-2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
+            self.gguf_writer.add_rope_freq_base(scaled_base)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+            self.gguf_writer.add_rope_scaling_factor(1)
+            #There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
+            self.gguf_writer.add_context_length(256 * 1024) # 256k context length
+
+            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
+            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
+                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
+
+    _experts: list[dict[str, Tensor]] | None = None
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # process the experts separately
+        if name == "model.embed_tokens.weight":
+            self._tok_embd = data_torch.clone()
+
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+
         if name.find("mlp.experts") != -1:
             n_experts = self.hparams["num_experts"]
             assert bid is not None

-            tensors: list[tuple[str, Tensor]] = []
-
             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]
@@ -6489,6 +6516,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

             if len(self._experts[bid]) >= n_experts * 3:
                 # merge the experts into a single 3d tensor
+                tensors: list[tuple[str, Tensor]] = []
                 for w_name in ["down_proj", "gate_proj", "up_proj"]:
                     datas: list[Tensor] = []

@@ -6498,11 +6526,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                         del self._experts[bid][ename]

                     data_torch = torch.stack(datas, dim=0)
-                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
-                    new_name = self.map_tensor_name(merged_name)
-                    tensors.append((new_name, data_torch))

                 return tensors
@@ -6511,6 +6536,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         return [(self.map_tensor_name(name), data_torch)]

+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+

 ###### CONVERSION LOGIC ######
@@ -6600,7 +6632,7 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q4_0", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
         help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
@@ -6732,6 +6764,7 @@ def main() -> None:
         "f32": gguf.LlamaFileType.ALL_F32,
         "f16": gguf.LlamaFileType.MOSTLY_F16,
         "bf16": gguf.LlamaFileType.MOSTLY_BF16,
+        "q4_0": gguf.LlamaFileType.MOSTLY_Q4_0,
         "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
         "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0,
         "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0,
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 403dfe8fcc028..2584dd9c248a5 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -1672,7 +1672,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
             { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,      "blk.%d.attn_rot_embd" },
             { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index a299c89ecc7e4..71ee431a977ba 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -705,13 +705,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);

-    if (arch == LLM_ARCH_HUNYUAN_MOE) {
-        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_expert_used, n_tokens]
-        weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [1, n_tokens]
-        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); // [1, n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_scaled", il);
-    }
-
     if (norm_w) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 172a1d7409caa..4992f81c0e46d 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1511,11 +1511,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);

-                // TODO: read from gguf
-                float n_dim = hparams.n_embd_head_k;
-                float alpha = 1000.0f; // NTK-Aware
-                hparams.rope_freq_base_train = 10000.0f * std::powf(alpha, n_dim / (n_dim - 2.0f));
-
                 switch (hparams.n_layer) {
                     case 32: type = LLM_TYPE_A13B; break;
                     default: type = LLM_TYPE_UNKNOWN;
@@ -14437,8 +14432,10 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
                        model.layers[il].ffn_down_exps,
                        nullptr,
                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, false,
-                        false, 0.0,
+                        LLM_FFN_SILU,
+                        true, // norm_topk_prob
+                        false,
+                        0.0,
                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                        il);
                cb(cur_moe, "ffn_moe_out", il);
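
Note on the RoPE hunks: the conversion script now bakes HunYuan's NTK-aware alpha scaling directly into the GGUF rope_freq_base, which is why the hard-coded recomputation removed from llama-model.cpp is no longer needed. A minimal standalone sketch of the same formula, using the values quoted in the diff (base 10000, alpha 1000, head dim 128):

# Sketch of the NTK-aware alpha scaling written by set_gguf_parameters above:
# scaled_base = base * alpha^(dim / (dim - 2))
def ntk_alpha_scaled_base(base: float = 10000.0, alpha: float = 1000.0, dim: int = 128) -> float:
    return base * (alpha ** (dim / (dim - 2)))

print(ntk_alpha_scaled_base())  # ~11158839.93, matching the value noted in the diff comment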
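The expert-merging branch of modify_tensors buffers each layer's per-expert down/gate/up projection weights and, once all of them have been collected, stacks them into a single 3D tensor per projection with torch.stack(datas, dim=0). A small illustration of that shape change; the sizes below are made up for the example, not HunYuan's real dimensions:

import torch

# Hypothetical sizes, chosen only to illustrate the stacking
n_experts, n_ff, n_embd = 4, 32, 16

# one 2D weight tensor per expert, as they appear in the HF checkpoint ...
per_expert = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]

# ... stacked along a new leading dimension, mirroring torch.stack(datas, dim=0) in the diff
merged = torch.stack(per_expert, dim=0)
print(merged.shape)  # torch.Size([4, 32, 16]) -> one 3D tensor per projection (down/gate/up)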
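The llama-graph.cpp hunk can drop the HunYuan-specific rescaling because build_moe_ffn already performs the same per-token renormalization behind its norm_w flag (annotated as norm_topk_prob at the call site), which llm_build_hunyuan_moe now enables. The operation in question simply divides each token's selected expert weights by their sum; a small torch sketch, with illustrative shapes only:

import torch

n_tokens, n_expert_used = 3, 8  # illustrative sizes only
weights = torch.rand(n_tokens, n_expert_used)  # softmax probs of the selected (top-k) experts

# norm_topk_prob: rescale so each token's selected expert weights sum to 1,
# the same effect as the removed ggml_div(weights, ggml_sum_rows(weights)) special case
weights = weights / weights.sum(dim=-1, keepdim=True)
print(weights.sum(dim=-1))  # tensor([1., 1., 1.])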