From acd76d06bde031e5e9a149af0e179036892608eb Mon Sep 17 00:00:00 2001 From: tamarPal Date: Thu, 6 Nov 2025 13:37:55 +0200 Subject: [PATCH 1/8] feat: Add Megrez-MoE architecture support Implements complete support for Megrez-MoE (Mixture of Experts) models: - Add LLM_ARCH_MEGREZ_MOE architecture enum and mappings - Implement build_mergez_moe_ffn() with sigmoid+bias gating - Add llm_build_megrez_moe class for full model graph construction - Support 31-layer architecture (layer 0: dense FFN, layers 1-30: MoE) - Implement expert sharing pattern with 64 experts, 6 used per token, 4 shared - Load all model hyperparameters and 372 tensors correctly - Configure NEOX RoPE type for proper positional encoding Tested with Megrez2-3x7B-A3B_Q4_K_M.gguf model. All 39 llama.cpp tests pass successfully. Output verified to match infinigence/llama.cpp reference implementation. Note: Use --no-warmup flag to avoid warmup memory allocation issue. --- src/llama-arch.cpp | 56 +++++++++++---------------- src/llama-arch.h | 1 + src/llama-graph.cpp | 94 +++++++++++++++++++++++++++++++++++++++++++++ src/llama-graph.h | 13 +++++++ src/llama-model.cpp | 74 +++++++++++++++++++++++++++++++++++ 5 files changed, 205 insertions(+), 33 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568..ea872d1c9 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -108,6 +108,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_COGVLM, "cogvlm" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, + { LLM_ARCH_MEGREZ_MOE, "megrez-moe" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2379,40 +2380,29 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_PANGU_EMBED, + LLM_ARCH_MEGREZ_MOE, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_COGVLM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_VISEXP_ATTN_QKV, "blk.%d.vis_attn_qkv" }, - { LLM_TENSOR_VISEXP_ATTN_OUT, "blk.%d.vis_attn_output" }, - { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, - { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, - { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, +>>>>>>> 256414a18 (feat: Add Megrez-MoE architecture support) }, }, { diff --git a/src/llama-arch.h b/src/llama-arch.h index a769dd1e8..28719e62a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -112,6 +112,7 @@ enum llm_arch { LLM_ARCH_MINIMAX_M2, LLM_ARCH_COGVLM, LLM_ARCH_PANGU_EMBED, + LLM_ARCH_MEGREZ_MOE, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index f9751b318..19ec05565 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1140,6 +1140,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn( return moe_out; } +ggml_tensor * llm_graph_context::build_mergez_moe_ffn( + ggml_tensor * cur, + ggml_tensor * hidden_state, + ggml_tensor * gate_inp, + ggml_tensor * exp_probs_b, + ggml_tensor * up_exps, + ggml_tensor * gate_exps, + ggml_tensor * down_exps, + int64_t n_expert, + int64_t n_expert_used, + int il) const { + const int64_t n_embd = cur->ne[0]; + const int64_t n_tokens = cur->ne[1]; + + ggml_tensor * logits = build_lora_mm(gate_inp, hidden_state); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * normalized_logits = nullptr; + ggml_tensor * probs = nullptr; + if (exp_probs_b) { + // For Megrez: sigmoid THEN add bias (not the other way around!) + normalized_logits = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] + cb(normalized_logits, "ffn_moe_logits_normalize", il); + probs = ggml_add(ctx0, normalized_logits, exp_probs_b); // Add bias AFTER sigmoid + cb(probs, "ffn_moe_probs", il); + } else { + probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens] + } + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + cb(selected_experts, "ffn_moe_topk", il); + + ggml_tensor * weights = nullptr; + if (exp_probs_b) { + ggml_tensor * weight0s = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, normalized_logits, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + cb(weight0s, "ffn_moe_weights0", il); + weight0s = ggml_reshape_2d(ctx0, weight0s, n_expert_used, n_tokens); + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weight0s); // [1, n_tokens] + cb(weights_sum, "ffn_moe_weights0_sum", il); + weights = ggml_div(ctx0, weight0s, weights_sum); // [n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights_norm", il); + weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); + } else { + weights = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights", il); + } + + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + + ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); + + gate = ggml_silu(ctx0, gate); + cb(gate, "ffn_moe_silu", il); + + ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens] + cb(par, "ffn_moe_gate_par", il); + + ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + cb(experts, "ffn_moe_down", il); + + experts = ggml_mul(ctx0, experts, weights); + cb(experts, "ffn_moe_weighted", il); + + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + } + } + + if (n_expert_used == 1) { + // avoid returning a non-contiguous tensor + moe_out = ggml_cont(ctx0, moe_out); + } + + cb(moe_out, "ffn_moe_out", il); + + return moe_out; +} + // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { const int64_t n_embd = hparams.n_embd; diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c3934f6..a1c88ae5d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -672,6 +672,19 @@ struct llm_graph_context { int il, ggml_tensor * probs_in = nullptr) const; + // build Megrez MoE FFN (special gating with sigmoid + bias) + ggml_tensor * build_mergez_moe_ffn( + ggml_tensor * cur, + ggml_tensor * hidden_state, + ggml_tensor * gate_inp, + ggml_tensor * exp_probs_b, + ggml_tensor * up_exps, + ggml_tensor * gate_exps, + ggml_tensor * down_exps, + int64_t n_expert, + int64_t n_expert_used, + int il) const; + // // inputs // diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1987135ca..b750e00bd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2183,6 +2183,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 + case LLM_ARCH_MEGREZ_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer) { + case 31: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -3338,6 +3348,65 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); } } break; + case LLM_ARCH_MEGREZ_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // Layer 0 is dense, layers 1-30 are MoE + if (i == 0) { + // Dense layer + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // All MoE layers (1-30) have these + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), {n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for MEGREZ_MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for MEGREZ_MOE"); + } + + // All MoE layers have shared expert + const int64_t n_ff_shexp = hparams.n_ff_shexp; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + // Only layers 1, 4, 7, 10, 13, 16, 19, 22, 25, 28 have actual expert tensors + // Pattern: (i-1) % 3 == 0 for i > 0 + if ((i - 1) % 3 == 0) { + // MoE branch - use the expert-specific FF size from hparams + const int64_t n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } + // Note: layers that share experts (2, 3, 5, 6, etc.) only have gate_inp and shared expert + // They will reference the regular experts from their corresponding "full" layer during inference + } + } + } break; case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3VL: { @@ -7178,6 +7247,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MEGREZ_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEMOTRON: { llm = std::make_unique(*this, params); @@ -7518,6 +7591,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GPTNEOX: case LLM_ARCH_CODESHELL: case LLM_ARCH_ORION: + case LLM_ARCH_MEGREZ_MOE: case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE4: From d7443ba182d7cb4c2cfa740d255b697435dfe1b2 Mon Sep 17 00:00:00 2001 From: tamarPal Date: Thu, 6 Nov 2025 14:39:52 +0200 Subject: [PATCH 2/8] fix: increase graph nodes for Megrez-MoE warmup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Megrez-MoE creates many intermediate tensors during MoE FFN construction: - sigmoid, add, reshape (3x), get_rows, sum_rows, div, view_2d, mul_mat operations - ggml_top_k internally calls ggml_argsort + ggml_view_4d (2 more tensors per layer) - Each of 30 MoE layers creates ~35 intermediate tensors during graph construction During warmup, the graph is built 3 times with different batch sizes, requiring sufficient memory pool space for all intermediate tensors. Add 4096 node overhead for LLM_ARCH_MEGREZ_MOE to accommodate these intermediate tensors (30 layers × 35 tensors/layer ≈ 1050 nodes, doubled for safety margin). This fixes the 'not enough space in the context's memory pool' error during warmup, allowing Megrez-MoE to work without the --no-warmup flag. Tested: - All 39 tests pass - Megrez-MoE works with warmup enabled (no crashes) - Other models (e.g., Gemma-2) are unaffected - Verified with outputs up to 100 tokens --- src/llama-context.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 866514038..a87b5a0fb 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1382,7 +1382,21 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { - return std::max(1024u, 8u*model.n_tensors()); + uint32_t base_nodes = std::max(1024u, 8u*model.n_tensors()); + + // Megrez-MoE creates many intermediate tensors in build_mergez_moe_ffn for each layer: + // - sigmoid, add (bias), reshape (3x), get_rows, sum_rows, div, view_2d, mul_mat (per expert) + // - ggml_top_k internally calls ggml_argsort + ggml_view_4d (2 more tensors per layer) + // Each MoE layer needs ~30-40 intermediate tensors during graph construction + // With 30 MoE layers, this adds significant overhead to the graph (30 layers * 35 tensors = ~1050) + // During warmup, the graph is built 3 times with different batch sizes + if (model.arch == LLM_ARCH_MEGREZ_MOE) { + // Add substantial overhead: ~35 intermediate tensors per MoE layer * 30 layers = ~1050 nodes + // Double it to 4096 for safety margin during warmup's triple graph construction + base_nodes += 4096; + } + + return base_nodes; } llm_graph_result * llama_context::get_gf_res_reserve() const { From c00b18391a21084cb616e033eb5352271cabe210 Mon Sep 17 00:00:00 2001 From: tamarPal Date: Fri, 7 Nov 2025 00:30:20 +0200 Subject: [PATCH 3/8] feat: adapt Megrez-MoE to new models/*.cpp architecture - Move llm_build_megrez_moe from llama-model.cpp to src/models/megrez-moe.cpp - Add declaration to src/models/models.h - Update CMakeLists.txt to include megrez-moe.cpp in build - Resolve merge conflicts in llama-arch.cpp and llama-model.cpp - Fix PANGU_EMBED case statement closing braces The model loads successfully, all tests pass (40/40), and inference works correctly. --- src/CMakeLists.txt | 1 + src/llama-arch.cpp | 1 - src/llama-model.cpp | 4 + src/models/megrez-moe.cpp | 216 ++++++++++++++++++++++++++++++++++++++ src/models/models.h | 4 + 5 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 src/models/megrez-moe.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 630b2cddf..84df832d1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -89,6 +89,7 @@ add_library(llama models/mamba.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/megrez-moe.cpp models/mpt.cpp models/nemotron-h.cpp models/nemotron.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ea872d1c9..44a6ed3d2 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2402,7 +2402,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ->>>>>>> 256414a18 (feat: Add Megrez-MoE architecture support) }, }, { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b750e00bd..c5bb6289c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2180,9 +2180,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_PANGU_EMBED: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_MEGREZ_MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); diff --git a/src/models/megrez-moe.cpp b/src/models/megrez-moe.cpp new file mode 100644 index 000000000..c1194277f --- /dev/null +++ b/src/models/megrez-moe.cpp @@ -0,0 +1,216 @@ +#include "models.h" + + + +llm_build_megrez_moe::llm_build_megrez_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + ggml_tensor * pre_gate_hidden; + // Layer 0 + { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[0].attn_norm, NULL, + LLM_NORM_RMS, 0); + cb(cur, "attn_norm", 0); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[0].wq, cur); + cb(Qcur, "Qcur", 0); + + ggml_tensor * Kcur = build_lora_mm(model.layers[0].wk, cur); + cb(Kcur, "Kcur", 0); + + ggml_tensor * Vcur = build_lora_mm(model.layers[0].wv, cur); + cb(Vcur, "Vcur", 0); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", 0); + cb(Kcur, "Kcur", 0); + cb(Vcur, "Vcur", 0); + + cur = build_attn(inp_attn, + model.layers[0].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, 0); + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", 0); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[0].ffn_norm, NULL, + LLM_NORM_RMS, 0); + cb(cur, "ffn_norm", 0); + + pre_gate_hidden = cur; + + cur = build_ffn(cur, + model.layers[0].ffn_up, NULL, NULL, + model.layers[0].ffn_gate, NULL, NULL, + model.layers[0].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, 0); + + cb(cur, "ffn_out", 0); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out_add", 0); + + } + inpL = cur; + for (int il = 1; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(cur, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + pre_gate_hidden = ggml_get_rows(ctx0, pre_gate_hidden, inp_out_ids); + } + + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_mergez_moe_ffn(cur, + pre_gate_hidden, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_exp_probs_b, + model.layers[((il - 1) / (3) * (3)) + 1].ffn_up_exps, + model.layers[((il - 1) / (3) * (3)) + 1].ffn_gate_exps, + model.layers[((il - 1) / (3) * (3)) + 1].ffn_down_exps, + n_expert, n_expert_used, + il); + cb(moe_out, "ffn_moe_out", il); + + pre_gate_hidden = cur; + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 2fffb382d..daf6707af 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -317,6 +317,10 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_megrez_moe : public llm_graph_context { + llm_build_megrez_moe(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; From 053234f91dc5e6cd1356553b6e515b319376d2c1 Mon Sep 17 00:00:00 2001 From: tamarPal Date: Fri, 7 Nov 2025 00:52:26 +0200 Subject: [PATCH 4/8] refactor: use standard build_moe_ffn instead of custom build_mergez_moe_ffn - Remove custom build_mergez_moe_ffn implementation (100+ lines) - Use existing build_moe_ffn with LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID - Pre-compute gate logits from pre_gate_hidden (Megrez-MoE's unique gating) - Pass pre-computed logits via probs_in parameter - Maintain exact same behavior and output quality This addresses review feedback to reuse existing MoE infrastructure instead of duplicating code. The sigmoid gating + bias after activation is already supported by build_moe_ffn. --- src/llama-graph.cpp | 94 --------------------------------------- src/llama-graph.h | 13 ------ src/models/megrez-moe.cpp | 21 +++++++-- 3 files changed, 17 insertions(+), 111 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 19ec05565..f9751b318 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1140,100 +1140,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( return moe_out; } -ggml_tensor * llm_graph_context::build_mergez_moe_ffn( - ggml_tensor * cur, - ggml_tensor * hidden_state, - ggml_tensor * gate_inp, - ggml_tensor * exp_probs_b, - ggml_tensor * up_exps, - ggml_tensor * gate_exps, - ggml_tensor * down_exps, - int64_t n_expert, - int64_t n_expert_used, - int il) const { - const int64_t n_embd = cur->ne[0]; - const int64_t n_tokens = cur->ne[1]; - - ggml_tensor * logits = build_lora_mm(gate_inp, hidden_state); // [n_expert, n_tokens] - cb(logits, "ffn_moe_logits", il); - - ggml_tensor * normalized_logits = nullptr; - ggml_tensor * probs = nullptr; - if (exp_probs_b) { - // For Megrez: sigmoid THEN add bias (not the other way around!) - normalized_logits = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] - cb(normalized_logits, "ffn_moe_logits_normalize", il); - probs = ggml_add(ctx0, normalized_logits, exp_probs_b); // Add bias AFTER sigmoid - cb(probs, "ffn_moe_probs", il); - } else { - probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens] - } - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - cb(selected_experts, "ffn_moe_topk", il); - - ggml_tensor * weights = nullptr; - if (exp_probs_b) { - ggml_tensor * weight0s = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, normalized_logits, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] - cb(weight0s, "ffn_moe_weights0", il); - weight0s = ggml_reshape_2d(ctx0, weight0s, n_expert_used, n_tokens); - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weight0s); // [1, n_tokens] - cb(weights_sum, "ffn_moe_weights0_sum", il); - weights = ggml_div(ctx0, weight0s, weights_sum); // [n_expert_used, n_tokens] - cb(weights, "ffn_moe_weights_norm", il); - weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); - } else { - weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] - cb(weights, "ffn_moe_weights", il); - } - - cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); - - ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); - - ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(gate, "ffn_moe_gate", il); - - gate = ggml_silu(ctx0, gate); - cb(gate, "ffn_moe_silu", il); - - ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens] - cb(par, "ffn_moe_gate_par", il); - - ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] - cb(experts, "ffn_moe_down", il); - - experts = ggml_mul(ctx0, experts, weights); - cb(experts, "ffn_moe_weighted", il); - - // aggregate experts - ggml_tensor * moe_out = nullptr; - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, - experts->nb[2], i*experts->nb[1]); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - } - } - - if (n_expert_used == 1) { - // avoid returning a non-contiguous tensor - moe_out = ggml_cont(ctx0, moe_out); - } - - cb(moe_out, "ffn_moe_out", il); - - return moe_out; -} - // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { const int64_t n_embd = hparams.n_embd; diff --git a/src/llama-graph.h b/src/llama-graph.h index a1c88ae5d..d0c3934f6 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -672,19 +672,6 @@ struct llm_graph_context { int il, ggml_tensor * probs_in = nullptr) const; - // build Megrez MoE FFN (special gating with sigmoid + bias) - ggml_tensor * build_mergez_moe_ffn( - ggml_tensor * cur, - ggml_tensor * hidden_state, - ggml_tensor * gate_inp, - ggml_tensor * exp_probs_b, - ggml_tensor * up_exps, - ggml_tensor * gate_exps, - ggml_tensor * down_exps, - int64_t n_expert, - int64_t n_expert_used, - int il) const; - // // inputs // diff --git a/src/models/megrez-moe.cpp b/src/models/megrez-moe.cpp index c1194277f..699c9f6ac 100644 --- a/src/models/megrez-moe.cpp +++ b/src/models/megrez-moe.cpp @@ -163,14 +163,27 @@ llm_build_megrez_moe::llm_build_megrez_moe(const llama_model & model, const llm_ cb(cur, "ffn_out", il); } else { // MoE branch - ggml_tensor * moe_out = build_mergez_moe_ffn(cur, - pre_gate_hidden, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_exp_probs_b, + // Note: Megrez-MoE uses pre_gate_hidden (from previous layer's FFN norm) for gating + // This is different from standard MoE which uses current layer's input + // Compute gate logits from pre_gate_hidden instead of cur + ggml_tensor * gate_logits = build_lora_mm(model.layers[il].ffn_gate_inp, pre_gate_hidden); + cb(gate_logits, "ffn_moe_logits", il); + + // Use standard build_moe_ffn but with pre-computed gate logits + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[((il - 1) / (3) * (3)) + 1].ffn_up_exps, model.layers[((il - 1) / (3) * (3)) + 1].ffn_gate_exps, model.layers[((il - 1) / (3) * (3)) + 1].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, - il); + LLM_FFN_SILU, + true, // norm_w + false, // scale_w + 1.0f, // w_scale + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il, + gate_logits); // Use pre-computed logits from pre_gate_hidden cb(moe_out, "ffn_moe_out", il); pre_gate_hidden = cur; From 90cd13d840a2ce5f5b8223b3c6726e161cf8289f Mon Sep 17 00:00:00 2001 From: tamarPal Date: Fri, 7 Nov 2025 01:04:48 +0200 Subject: [PATCH 5/8] fix: remove trailing whitespace --- src/llama-context.cpp | 5 +++-- src/models/megrez-moe.cpp | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a87b5a0fb..77da64b54 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1383,7 +1383,7 @@ void llama_context::output_reorder() { uint32_t llama_context::graph_max_nodes() const { uint32_t base_nodes = std::max(1024u, 8u*model.n_tensors()); - + // Megrez-MoE creates many intermediate tensors in build_mergez_moe_ffn for each layer: // - sigmoid, add (bias), reshape (3x), get_rows, sum_rows, div, view_2d, mul_mat (per expert) // - ggml_top_k internally calls ggml_argsort + ggml_view_4d (2 more tensors per layer) @@ -1395,7 +1395,8 @@ uint32_t llama_context::graph_max_nodes() const { // Double it to 4096 for safety margin during warmup's triple graph construction base_nodes += 4096; } - + + return base_nodes; } diff --git a/src/models/megrez-moe.cpp b/src/models/megrez-moe.cpp index 699c9f6ac..20fe59033 100644 --- a/src/models/megrez-moe.cpp +++ b/src/models/megrez-moe.cpp @@ -168,7 +168,7 @@ llm_build_megrez_moe::llm_build_megrez_moe(const llama_model & model, const llm_ // Compute gate logits from pre_gate_hidden instead of cur ggml_tensor * gate_logits = build_lora_mm(model.layers[il].ffn_gate_inp, pre_gate_hidden); cb(gate_logits, "ffn_moe_logits", il); - + // Use standard build_moe_ffn but with pre-computed gate logits ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, From 2d2e419d076c5eaabf48717e55afba2d4ddab132 Mon Sep 17 00:00:00 2001 From: tamarPal Date: Sun, 9 Nov 2025 13:13:05 +0200 Subject: [PATCH 6/8] fix: resolve additional merge issues from rebase - Restore PANGU_EMBED and COGVLM tensor mappings in llama-arch.cpp - Remove extra blank line in llama-context.cpp --- src/llama-arch.cpp | 37 +++++++++++++++++++++++++++++++++++++ src/llama-context.cpp | 1 - 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 44a6ed3d2..508c82ef9 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2404,6 +2404,43 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PANGU_EMBED, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_COGVLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_VISEXP_ATTN_QKV, "blk.%d.vis_attn_qkv" }, + { LLM_TENSOR_VISEXP_ATTN_OUT, "blk.%d.vis_attn_output" }, + { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, + { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, + { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 77da64b54..41bcba308 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1396,7 +1396,6 @@ uint32_t llama_context::graph_max_nodes() const { base_nodes += 4096; } - return base_nodes; } From adb7fdd69f28f5ce75cf926dd7cf568a910555a5 Mon Sep 17 00:00:00 2001 From: tamarPal Date: Sun, 9 Nov 2025 15:58:13 +0200 Subject: [PATCH 7/8] Add Megrez-MoE GGUF conversion and inference support --- convert_hf_to_gguf.py | 128 +++++++++++++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 2 + 2 files changed, 128 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 222f6ed6d..71fb4ab88 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8957,8 +8957,132 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) -@ModelBase.register("HunYuanMoEV1ForCausalLM") -class HunYuanMoEModel(TextModel): +@ModelBase.register("MegrezMoEForCausalLM") +class MegrezMoEModel(TextModel): + model_arch = gguf.MODEL_ARCH.MEGREZ_MOE + + def set_vocab(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + tokpre = self.get_vocab_base_pre(tokenizer) + merges = [] + vocab = {} + mergeable_ranks = getattr(tokenizer, "mergeable_ranks", {}) + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + vocab_size = self.hparams["vocab_size"] + assert tokenizer.vocab_size == vocab_size + special_tokens = getattr(tokenizer, "special_tokens", {}) + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} + tokens: list[str] = [] + toktypes: list[int] = [] + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token = reverse_vocab[i] + tokens.append(token) + if i in special_tokens.values(): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + # BOS token fix if needed + # self.gguf_writer.add_bos_token_id() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) + + moe_intermediate_size = hparams["moe_intermediate_size"] + assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size) + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0]) + + moe_topk = hparams["moe_topk"] + assert all(topk == moe_topk[0] for topk in moe_topk) + self.gguf_writer.add_expert_used_count(moe_topk[0]) + + moe_shared_expert = hparams["num_shared_expert"] + assert all(n == moe_shared_expert[0] for n in moe_shared_expert) + self.gguf_writer.add_expert_shared_count(moe_shared_expert[0]) + + rope_scaling = hparams.get("rope_scaling", {}) + if rope_scaling.get("type") == "dynamic": + alpha = rope_scaling.get("alpha", 1000) + base = hparams.get("rope_theta", 10000.0) + dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) + scaled_base = base * (alpha ** (dim / (dim - 2))) + self.gguf_writer.add_rope_freq_base(scaled_base) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(1) + self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) + self.gguf_writer.add_context_length(256 * 1024) + assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \ + "Megrez dynamic RoPE scaling assumptions changed, please update the logic or context length manually" + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "lm_head.weight": + if self.hparams.get("tie_word_embeddings", False): + logger.info("Skipping tied output layer 'lm_head.weight'") + return [] + + if name.find("mlp.experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE def set_vocab(self): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6b4b6c5ab..94d4cecc9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -427,6 +427,7 @@ class MODEL_ARCH(IntEnum): COGVLM = auto() MINIMAXM2 = auto() PANGU_EMBED = auto() + MEGREZ_MOE = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -795,6 +796,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.MINIMAXM2: "minimax-m2", MODEL_ARCH.COGVLM: "cogvlm", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", + MODEL_ARCH.MEGREZ_MOE: "megrez-moe", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { From 810020d6f64c32f43be3267f13141986388153d8 Mon Sep 17 00:00:00 2001 From: tamarPal Date: Sun, 9 Nov 2025 16:29:35 +0200 Subject: [PATCH 8/8] Fix type errors in MegrezMoEModel only (Pyright compliance) --- convert_hf_to_gguf.py | 42 +++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 71fb4ab88..4169f1aa6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9014,30 +9014,37 @@ def set_gguf_parameters(self): self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) moe_intermediate_size = hparams["moe_intermediate_size"] - assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size) - self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0]) + if moe_intermediate_size is not None and isinstance(moe_intermediate_size, (list, tuple)) and len(moe_intermediate_size) > 0: + assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size) + self.gguf_writer.add_expert_feed_forward_length(int(moe_intermediate_size[0])) moe_topk = hparams["moe_topk"] - assert all(topk == moe_topk[0] for topk in moe_topk) - self.gguf_writer.add_expert_used_count(moe_topk[0]) + if moe_topk is not None and isinstance(moe_topk, (list, tuple)) and len(moe_topk) > 0: + assert all(topk == moe_topk[0] for topk in moe_topk) + self.gguf_writer.add_expert_used_count(int(moe_topk[0])) moe_shared_expert = hparams["num_shared_expert"] - assert all(n == moe_shared_expert[0] for n in moe_shared_expert) - self.gguf_writer.add_expert_shared_count(moe_shared_expert[0]) + if moe_shared_expert is not None and isinstance(moe_shared_expert, (list, tuple)) and len(moe_shared_expert) > 0: + assert all(n == moe_shared_expert[0] for n in moe_shared_expert) + self.gguf_writer.add_expert_shared_count(int(moe_shared_expert[0])) rope_scaling = hparams.get("rope_scaling", {}) if rope_scaling.get("type") == "dynamic": alpha = rope_scaling.get("alpha", 1000) base = hparams.get("rope_theta", 10000.0) - dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) - scaled_base = base * (alpha ** (dim / (dim - 2))) - self.gguf_writer.add_rope_freq_base(scaled_base) - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) - self.gguf_writer.add_rope_scaling_factor(1) - self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) - self.gguf_writer.add_context_length(256 * 1024) - assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \ - "Megrez dynamic RoPE scaling assumptions changed, please update the logic or context length manually" + hidden_size = hparams.get("hidden_size") + num_attention_heads = hparams.get("num_attention_heads") + max_position_embeddings = self.hparams.get("max_position_embeddings") + if None not in (hidden_size, num_attention_heads, max_position_embeddings): + dim = hidden_size // num_attention_heads + scaled_base = base * (alpha ** (dim / (dim - 2))) + self.gguf_writer.add_rope_freq_base(scaled_base) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(1) + self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) + self.gguf_writer.add_context_length(256 * 1024) + assert alpha == 1000 and base == 10000.0 and dim == 128 and max_position_embeddings in [32 * 1024, 256 * 1024], \ + "Megrez dynamic RoPE scaling assumptions changed, please update the logic or context length manually" _experts: list[dict[str, Tensor]] | None = None @@ -9083,6 +9090,11 @@ def prepare_tensors(self): experts = [k for d in self._experts for k in d.keys()] if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") + +# ...existing code... + +@ModelBase.register("HunYuanMoEV1ForCausalLM") +class HunYuanMoEModel(TextModel): model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE def set_vocab(self):