From 5e78e887369a12aaab450ec713507e646a44c45b Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Fri, 27 Jun 2025 18:38:13 -0500 Subject: [PATCH 01/10] almost working --- convert_hf_to_gguf.py | 74 ++++++++++++++++++++++--------------- gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 ++ 3 files changed, 48 insertions(+), 30 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 12de60442a43e..1ad4d913435f5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -842,14 +842,14 @@ def get_vocab_base_pre(self, tokenizer) -> str: def _set_vocab_none(self) -> None: self.gguf_writer.add_tokenizer_model("none") - def _set_vocab_gpt2(self) -> None: + def _set_vocab_gpt2(self, load_merges=True) -> None: tokens, toktypes, tokpre = self.get_vocab_base() self.gguf_writer.add_tokenizer_model("gpt2") self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges) special_vocab.add_to_gguf(self.gguf_writer) def _set_vocab_qwen(self): @@ -6394,15 +6394,14 @@ def set_gguf_parameters(self): @ModelBase.register("HunYuanMoEV1ForCausalLM") -class HunYuanMoEModel(LlamaModel): +class HunYuanMoEModel(TextModel): model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE - undo_permute = False def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def set_vocab(self): - self._set_vocab_gpt2() + self._set_vocab_gpt2(load_merges=False) def get_vocab_base(self) -> tuple[list[str], list[int], str]: tokens: list[str] = [] @@ -6411,52 +6410,41 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) - # merge logic is copied from QwenModel, maybe incorrect merges = [] - vocab = {} mergeable_ranks = tokenizer.mergeable_ranks for token, rank in mergeable_ranks.items(): - vocab[QwenModel.token_bytes_to_string(token)] = rank if len(token) == 1: continue + # bpe() will decompose the token into its smallest parts and then + # re-merge them. If the token is a valid merge, bpe() will return + # the two pieces that were merged to create it. merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) if len(merged) == 2: merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) self.gguf_writer.add_token_merges(merges) + vocab_size = self.hparams["vocab_size"] + reverse_vocab = tokenizer.decoder - assert max(reverse_vocab.keys()) < tokenizer.vocab_size + assert max(reverse_vocab.keys()) < tokenizer.vocab_size, tokenizer.vocab_size == vocab_size tokpre = self.get_vocab_base_pre(tokenizer) - added_vocab = tokenizer.get_added_vocab() + special_token_ids = set(tokenizer.special_tokens.values()) - added_tokens_decoder = tokenizer.added_tokens_decoder + tokens: list[str] = [] + toktypes: list[int] = [] - for i in range(tokenizer.vocab_size): + for i in range(vocab_size): if i not in reverse_vocab: tokens.append(f"[PAD{i}]") toktypes.append(gguf.TokenType.UNUSED) else: - token: str = reverse_vocab[i] - if token in added_vocab: - # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized. - # To avoid unexpected issues - we make sure to normalize non-normalized tokens - if not added_tokens_decoder[i].normalized: - previous_token = token - token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) - if previous_token != token: - logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") - - if added_tokens_decoder[i].special or self.does_token_look_special(token): - toktypes.append(gguf.TokenType.CONTROL) - else: - # NOTE: this was added for Gemma. - # Encoding and decoding the tokens above isn't sufficient for this case. - token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces - toktypes.append(gguf.TokenType.USER_DEFINED) + token = reverse_vocab[i] + tokens.append(token) + if i in special_token_ids: + toktypes.append(gguf.TokenType.CONTROL) else: toktypes.append(gguf.TokenType.NORMAL) - tokens.append(token) return tokens, toktypes, tokpre @@ -6474,6 +6462,25 @@ def set_gguf_parameters(self): assert all(topk == moe_topk[0] for topk in moe_topk) self.gguf_writer.add_expert_used_count(moe_topk[0]) + moe_shared_expert = self.hparams["num_shared_expert"] + assert all(n == moe_shared_expert[0] for n in moe_shared_expert) + self.gguf_writer.add_expert_shared_count(moe_shared_expert[0]) + + self.gguf_writer.add_qk_norm(self.hparams.get("use_qk_norm", True)) + + # Rope + rope_scaling = self.hparams.get("rope_scaling", {}) + if rope_scaling.get("type") == "dynamic": + logger.warning("Model uses 'dynamic' rope scaling, which is not yet supported in GGUF. " + "The resulting model may not work correctly with contexts longer than the training length.") + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + else: + # Fallback for other potential scaling types + # This part is inherited from TextModel and will handle standard rope_theta + pass + + _experts: list[dict[str, Tensor]] | None = None + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("mlp.experts") != -1: @@ -6511,6 +6518,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index e0a4f688ac7a8..31eb66cef2de4 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -148,6 +148,7 @@ class Attention: VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers" SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" + QK_NORM = "{arch}.attention.qk_norm" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d32cd479adb17..8096baf244da3 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -792,6 +792,9 @@ def add_group_norm_groups(self, value: int) -> None: def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + def add_qk_norm(self, value: bool) -> None: + self.add_bool(Keys.Attention.QK_NORM.format(arch=self.arch), value) + def add_q_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length) From d219580756922e2356ecf641521f2458a1a623fa Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Fri, 27 Jun 2025 19:03:04 -0500 Subject: [PATCH 02/10] skip embed, fix bos --- convert_hf_to_gguf.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 1ad4d913435f5..0fe28b8396f41 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6399,25 +6399,25 @@ class HunYuanMoEModel(TextModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # FIX for tied embeddings: Capture the token embeddings. + self._tok_embd = None def set_vocab(self): self._set_vocab_gpt2(load_merges=False) + # FIX for BOS token: Manually set the correct BOS token ID. + # The SpecialVocab helper gets incorrect id `bos_token_id: 1` from config.json. + self.gguf_writer.add_bos_token_id(127959) # <|bos|> def get_vocab_base(self) -> tuple[list[str], list[int], str]: - tokens: list[str] = [] - toktypes: list[int] = [] - from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + # Fake merges merges = [] mergeable_ranks = tokenizer.mergeable_ranks for token, rank in mergeable_ranks.items(): if len(token) == 1: continue - # bpe() will decompose the token into its smallest parts and then - # re-merge them. If the token is a valid merge, bpe() will return - # the two pieces that were merged to create it. merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) if len(merged) == 2: merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) @@ -6472,16 +6472,22 @@ def set_gguf_parameters(self): rope_scaling = self.hparams.get("rope_scaling", {}) if rope_scaling.get("type") == "dynamic": logger.warning("Model uses 'dynamic' rope scaling, which is not yet supported in GGUF. " - "The resulting model may not work correctly with contexts longer than the training length.") + "Long-context extrapolation will not work correctly. Setting rope scaling type to NONE.") self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) - else: - # Fallback for other potential scaling types - # This part is inherited from TextModel and will handle standard rope_theta - pass _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # FIX for tied embeddings: Capture the token embeddings. + if name == "model.embed_tokens.weight": + self._tok_embd = data_torch.clone() + + # FIX for tied embeddings: Skip the lm_head if it's tied. + if name == "lm_head.weight": + if self.hparams.get("tie_word_embeddings", False): + logger.info("Skipping tied output layer 'lm_head.weight'") + return [] + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["num_experts"] From 0fd393087d3bfbcaf7db07aaa5ceb557ff8d1dfa Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Fri, 27 Jun 2025 20:44:16 -0500 Subject: [PATCH 03/10] cleanup --- src/llama-arch.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 403dfe8fcc028..2584dd9c248a5 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1672,7 +1672,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, From b19ecae2dc1fbd998c6e0cc14ca933f06b5a5af7 Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Fri, 27 Jun 2025 20:54:56 -0500 Subject: [PATCH 04/10] yarn scaling --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0fe28b8396f41..62e93adc5ce5b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6471,9 +6471,9 @@ def set_gguf_parameters(self): # Rope rope_scaling = self.hparams.get("rope_scaling", {}) if rope_scaling.get("type") == "dynamic": - logger.warning("Model uses 'dynamic' rope scaling, which is not yet supported in GGUF. " - "Long-context extrapolation will not work correctly. Setting rope scaling type to NONE.") - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["max_position_embeddings"]) _experts: list[dict[str, Tensor]] | None = None From 245db1592003253cf90baf1698a5a45d8aa2acd1 Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Fri, 27 Jun 2025 22:52:09 -0500 Subject: [PATCH 05/10] cleanup --- convert_hf_to_gguf.py | 73 +++++++++++++++++++------------------------ src/llama-graph.cpp | 7 ----- src/llama-model.cpp | 6 ++-- 3 files changed, 36 insertions(+), 50 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 62e93adc5ce5b..ecfcc8125a885 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6399,20 +6399,22 @@ class HunYuanMoEModel(TextModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # FIX for tied embeddings: Capture the token embeddings. + # For handling tied embeddings self._tok_embd = None def set_vocab(self): - self._set_vocab_gpt2(load_merges=False) - # FIX for BOS token: Manually set the correct BOS token ID. - # The SpecialVocab helper gets incorrect id `bos_token_id: 1` from config.json. - self.gguf_writer.add_bos_token_id(127959) # <|bos|> - - def get_vocab_base(self) -> tuple[list[str], list[int], str]: + """ + A self-contained vocab implementation for the HunYuan tiktoken-based tokenizer. + This method correctly generates tokens, types, and the required "fake" merges + to satisfy the llama.cpp GGUF loader. + """ from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) - # Fake merges + # 1. Get the pre-tokenizer identifier hash + tokpre = self.get_vocab_base_pre(tokenizer) + + # 2. Reverse-engineer the merges list from mergeable_ranks merges = [] mergeable_ranks = tokenizer.mergeable_ranks for token, rank in mergeable_ranks.items(): @@ -6421,19 +6423,13 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) if len(merged) == 2: merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) - self.gguf_writer.add_token_merges(merges) + # 3. Generate the tokens and toktypes lists vocab_size = self.hparams["vocab_size"] - reverse_vocab = tokenizer.decoder - assert max(reverse_vocab.keys()) < tokenizer.vocab_size, tokenizer.vocab_size == vocab_size - - tokpre = self.get_vocab_base_pre(tokenizer) special_token_ids = set(tokenizer.special_tokens.values()) - tokens: list[str] = [] toktypes: list[int] = [] - for i in range(vocab_size): if i not in reverse_vocab: tokens.append(f"[PAD{i}]") @@ -6446,30 +6442,42 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: else: toktypes.append(gguf.TokenType.NORMAL) - return tokens, toktypes, tokpre + # 4. Write all vocab-related fields to the GGUF writer + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) + + # 5. Add special tokens and chat templates + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + # FIX for BOS token: Manually set the correct BOS token ID. + self.gguf_writer.add_bos_token_id(127959) # <|bos|> def set_gguf_parameters(self): super().set_gguf_parameters() + hparams = self.hparams - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) - self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) - moe_intermediate_size = self.hparams["moe_intermediate_size"] + moe_intermediate_size = hparams["moe_intermediate_size"] assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size) self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0]) - moe_topk = self.hparams["moe_topk"] + moe_topk = hparams["moe_topk"] assert all(topk == moe_topk[0] for topk in moe_topk) self.gguf_writer.add_expert_used_count(moe_topk[0]) - moe_shared_expert = self.hparams["num_shared_expert"] + moe_shared_expert = hparams["num_shared_expert"] assert all(n == moe_shared_expert[0] for n in moe_shared_expert) self.gguf_writer.add_expert_shared_count(moe_shared_expert[0]) - self.gguf_writer.add_qk_norm(self.hparams.get("use_qk_norm", True)) + self.gguf_writer.add_qk_norm(hparams.get("use_qk_norm", True)) # Rope - rope_scaling = self.hparams.get("rope_scaling", {}) + rope_scaling = hparams.get("rope_scaling", {}) if rope_scaling.get("type") == "dynamic": self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) @@ -6478,50 +6486,33 @@ def set_gguf_parameters(self): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # FIX for tied embeddings: Capture the token embeddings. if name == "model.embed_tokens.weight": self._tok_embd = data_torch.clone() - - # FIX for tied embeddings: Skip the lm_head if it's tied. if name == "lm_head.weight": if self.hparams.get("tie_word_embeddings", False): logger.info("Skipping tied output layer 'lm_head.weight'") return [] - - # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None - - tensors: list[tuple[str, Tensor]] = [] - if self._experts is None: self._experts = [{} for _ in range(self.block_count)] - self._experts[bid][name] = data_torch - if len(self._experts[bid]) >= n_experts * 3: - # merge the experts into a single 3d tensor + tensors: list[tuple[str, Tensor]] = [] for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] - for xid in range(n_experts): ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" datas.append(self._experts[bid][ename]) del self._experts[bid][ename] - data_torch = torch.stack(datas, dim=0) - merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - new_name = self.map_tensor_name(merged_name) - tensors.append((new_name, data_torch)) - return tensors else: return [] - return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a299c89ecc7e4..71ee431a977ba 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -705,13 +705,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); - if (arch == LLM_ARCH_HUNYUAN_MOE) { - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_expert_used, n_tokens] - weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [1, n_tokens] - weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); // [1, n_expert_used, n_tokens] - cb(weights, "ffn_moe_weights_scaled", il); - } - if (norm_w) { weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 86382aba5934c..3a25a2ccad377 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -14432,8 +14432,10 @@ struct llm_build_hunyuan_moe : public llm_graph_context { model.layers[il].ffn_down_exps, nullptr, n_expert, n_expert_used, - LLM_FFN_SILU, false, - false, 0.0, + LLM_FFN_SILU, + true, // norm_topk_prob + false, + 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur_moe, "ffn_moe_out", il); From 8fd547bd514abb65cd35c56c94e52c9400733653 Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Sat, 28 Jun 2025 17:12:38 -0500 Subject: [PATCH 06/10] failed token fix --- convert_hf_to_gguf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ecfcc8125a885..cd8dfd1652e01 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6416,18 +6416,21 @@ def set_vocab(self): # 2. Reverse-engineer the merges list from mergeable_ranks merges = [] + vocab = {} mergeable_ranks = tokenizer.mergeable_ranks for token, rank in mergeable_ranks.items(): + #vocab[QwenModel.token_bytes_to_string(token)] = rank if len(token) == 1: continue merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) - if len(merged) == 2: + if len(merged) == 2: #todo this is an assert in Qwen, why? merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) # 3. Generate the tokens and toktypes lists vocab_size = self.hparams["vocab_size"] - reverse_vocab = tokenizer.decoder special_token_ids = set(tokenizer.special_tokens.values()) + reverse_vocab = tokenizer.decoder + #reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_token_ids}.items()} tokens: list[str] = [] toktypes: list[int] = [] for i in range(vocab_size): From b20bd2639a9b42be5d90b496ee8a1ac966b958fd Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Sat, 28 Jun 2025 19:44:31 -0500 Subject: [PATCH 07/10] tokenization working --- convert_hf_to_gguf.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cd8dfd1652e01..9d39e0597ad17 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -346,6 +346,8 @@ def prepare_tensors(self): data_qtype = gguf.GGMLQuantizationType.BF16 elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: data_qtype = gguf.GGMLQuantizationType.Q8_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q4_0: + data_qtype = gguf.GGMLQuantizationType.Q4_0 elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0: data_qtype = gguf.GGMLQuantizationType.TQ1_0 elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0: @@ -6419,7 +6421,7 @@ def set_vocab(self): vocab = {} mergeable_ranks = tokenizer.mergeable_ranks for token, rank in mergeable_ranks.items(): - #vocab[QwenModel.token_bytes_to_string(token)] = rank + vocab[QwenModel.token_bytes_to_string(token)] = rank if len(token) == 1: continue merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) @@ -6428,9 +6430,8 @@ def set_vocab(self): # 3. Generate the tokens and toktypes lists vocab_size = self.hparams["vocab_size"] - special_token_ids = set(tokenizer.special_tokens.values()) - reverse_vocab = tokenizer.decoder - #reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_token_ids}.items()} + special_tokens = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} tokens: list[str] = [] toktypes: list[int] = [] for i in range(vocab_size): @@ -6440,7 +6441,7 @@ def set_vocab(self): else: token = reverse_vocab[i] tokens.append(token) - if i in special_token_ids: + if i in special_tokens.values(): toktypes.append(gguf.TokenType.CONTROL) else: toktypes.append(gguf.TokenType.NORMAL) @@ -6614,7 +6615,7 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q4_0", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( @@ -6746,6 +6747,7 @@ def main() -> None: "f32": gguf.LlamaFileType.ALL_F32, "f16": gguf.LlamaFileType.MOSTLY_F16, "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q4_0": gguf.LlamaFileType.MOSTLY_Q4_0, "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0, "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0, From 1221d944ca3b5e97426c5b31aa9d053ac0a323c7 Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Sun, 29 Jun 2025 08:52:41 -0500 Subject: [PATCH 08/10] cleanup and pr changes --- convert_hf_to_gguf.py | 25 +++++++++++++++---------- gguf-py/gguf/constants.py | 1 - gguf-py/gguf/gguf_writer.py | 3 --- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9d39e0597ad17..d4be8f96248a4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -844,14 +844,14 @@ def get_vocab_base_pre(self, tokenizer) -> str: def _set_vocab_none(self) -> None: self.gguf_writer.add_tokenizer_model("none") - def _set_vocab_gpt2(self, load_merges=True) -> None: + def _set_vocab_gpt2(self) -> None: tokens, toktypes, tokpre = self.get_vocab_base() self.gguf_writer.add_tokenizer_model("gpt2") self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) special_vocab.add_to_gguf(self.gguf_writer) def _set_vocab_qwen(self): @@ -6405,11 +6405,6 @@ def __init__(self, *args, **kwargs): self._tok_embd = None def set_vocab(self): - """ - A self-contained vocab implementation for the HunYuan tiktoken-based tokenizer. - This method correctly generates tokens, types, and the required "fake" merges - to satisfy the llama.cpp GGUF loader. - """ from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) @@ -6456,7 +6451,7 @@ def set_vocab(self): # 5. Add special tokens and chat templates special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) special_vocab.add_to_gguf(self.gguf_writer) - # FIX for BOS token: Manually set the correct BOS token ID. + # FIX for BOS token: Overwrite incorrect id read from config.json self.gguf_writer.add_bos_token_id(127959) # <|bos|> def set_gguf_parameters(self): @@ -6478,11 +6473,11 @@ def set_gguf_parameters(self): assert all(n == moe_shared_expert[0] for n in moe_shared_expert) self.gguf_writer.add_expert_shared_count(moe_shared_expert[0]) - self.gguf_writer.add_qk_norm(hparams.get("use_qk_norm", True)) - # Rope rope_scaling = hparams.get("rope_scaling", {}) if rope_scaling.get("type") == "dynamic": + # Not sure if YARN is correct here, and the factor in the config is only 1 anyway + # but the release claims to scale to 256k, which would be a factor of 8 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["max_position_embeddings"]) @@ -6492,31 +6487,41 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if name == "model.embed_tokens.weight": self._tok_embd = data_torch.clone() + if name == "lm_head.weight": if self.hparams.get("tie_word_embeddings", False): logger.info("Skipping tied output layer 'lm_head.weight'") return [] + if name.find("mlp.experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None + if self._experts is None: self._experts = [{} for _ in range(self.block_count)] + self._experts[bid][name] = data_torch + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor tensors: list[tuple[str, Tensor]] = [] for w_name in ["down_proj", "gate_proj", "up_proj"]: datas: list[Tensor] = [] + for xid in range(n_experts): ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" datas.append(self._experts[bid][ename]) del self._experts[bid][ename] + data_torch = torch.stack(datas, dim=0) merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" new_name = self.map_tensor_name(merged_name) tensors.append((new_name, data_torch)) + return tensors else: return [] + return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 31eb66cef2de4..e0a4f688ac7a8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -148,7 +148,6 @@ class Attention: VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers" SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" - QK_NORM = "{arch}.attention.qk_norm" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8096baf244da3..d32cd479adb17 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -792,9 +792,6 @@ def add_group_norm_groups(self, value: int) -> None: def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) - def add_qk_norm(self, value: bool) -> None: - self.add_bool(Keys.Attention.QK_NORM.format(arch=self.arch), value) - def add_q_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length) From 5471f5acf2ca4d8ca7500b4aa63616d6b9fb5368 Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Sun, 29 Jun 2025 08:59:36 -0500 Subject: [PATCH 09/10] vocab_size sanity check --- convert_hf_to_gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d4be8f96248a4..88ebf4fa6d4f4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6425,6 +6425,7 @@ def set_vocab(self): # 3. Generate the tokens and toktypes lists vocab_size = self.hparams["vocab_size"] + assert tokenizer.vocab_size == vocab_size special_tokens = tokenizer.special_tokens reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} tokens: list[str] = [] From 46c8b70cbc7346db95e45ebae4f1e0c68a9b8d86 Mon Sep 17 00:00:00 2001 From: kooshi <1934337+kooshi@users.noreply.github.com> Date: Sun, 29 Jun 2025 19:57:01 -0500 Subject: [PATCH 10/10] ntk alpha generic --- convert_hf_to_gguf.py | 21 ++++++++++++++++----- src/llama-model.cpp | 5 ----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 88ebf4fa6d4f4..26fcfef20fa1b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6477,11 +6477,22 @@ def set_gguf_parameters(self): # Rope rope_scaling = hparams.get("rope_scaling", {}) if rope_scaling.get("type") == "dynamic": - # Not sure if YARN is correct here, and the factor in the config is only 1 anyway - # but the release claims to scale to 256k, which would be a factor of 8 - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) - self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["max_position_embeddings"]) + # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf) + alpha = rope_scaling.get("alpha", 1000) + base = hparams.get("rope_theta", 10000.0) + dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128 + scaled_base = base * (alpha ** (dim / (dim-2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251 + self.gguf_writer.add_rope_freq_base(scaled_base) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(1) + #There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k + self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length + self.gguf_writer.add_context_length(256 * 1024) # 256k context length + + # if any of our assumptions about the values are wrong, something has changed and this may need to be updated + assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ + "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually" _experts: list[dict[str, Tensor]] | None = None diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d881821d46ae1..4992f81c0e46d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1511,11 +1511,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); - // TODO: read from gguf - float n_dim = hparams.n_embd_head_k; - float alpha = 1000.0f; // NTK-Aware - hparams.rope_freq_base_train = 10000.0f * std::powf(alpha, n_dim / (n_dim - 2.0f)); - switch (hparams.n_layer) { case 32: type = LLM_TYPE_A13B; break; default: type = LLM_TYPE_UNKNOWN;