diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 13448fd68116c..821517f4b1c6c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9033,6 +9033,141 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) +@ModelBase.register("MegrezMoeForCausalLM", "MegrezMoEForCausalLM") +class MegrezMoEModel(TextModel): + model_arch = gguf.MODEL_ARCH.MEGREZ_MOE + + def set_vocab(self): + # Megrez-MoE uses Qwen-style BPE tokenizer + # Use standard GPT2 vocab loading which handles BPE correctly + try: + self._set_vocab_gpt2() + except Exception: + # Fallback to Qwen-specific handling if needed + self._set_vocab_qwen() + # Note: special_vocab.add_to_gguf() is already called within + # _set_vocab_gpt2() and _set_vocab_qwen(), so no need to call it again + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + # MoE expert configuration + num_experts = hparams.get("num_experts") or hparams.get("n_routed_experts") + if num_experts is None: + raise ValueError("Missing 'num_experts' or 'n_routed_experts' in model config") + self.gguf_writer.add_expert_count(num_experts) + + # Shared expert FFN size + hidden_size = hparams.get("hidden_size", 2048) + shared_expert_ffn_size = int(hidden_size * 2.75) + self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_ffn_size) + + # Per-expert FFN size (should be consistent across all experts) + moe_intermediate_size = hparams.get("moe_intermediate_size") + if moe_intermediate_size is None: + raise ValueError("Missing 'moe_intermediate_size' in model config") + if not isinstance(moe_intermediate_size, list): + moe_intermediate_size = [moe_intermediate_size] + + # Validate all experts have same size + if not all(n == moe_intermediate_size[0] for n in moe_intermediate_size): + raise ValueError(f"All experts must have same FFN size, got: {moe_intermediate_size}") + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0]) + + # Shared expert count + num_shared_expert = hparams.get("num_shared_expert") or hparams.get("n_shared_experts") + if num_shared_expert is None: + raise ValueError("Missing 'num_shared_expert' or 'n_shared_experts' in model config") + if not isinstance(num_shared_expert, list): + num_shared_expert = [num_shared_expert] + + if not all(n == num_shared_expert[0] for n in num_shared_expert): + raise ValueError(f"All layers must have same shared expert count, got: {num_shared_expert}") + self.gguf_writer.add_expert_shared_count(num_shared_expert[0]) + + # RoPE scaling (Megrez may use dynamic scaling) + rope_scaling = hparams.get("rope_scaling") + if rope_scaling and rope_scaling.get("type") == "dynamic": + alpha = rope_scaling.get("alpha", 1000) + base = hparams.get("rope_theta", 10000.0) + hidden_size = hparams.get("hidden_size") + num_attention_heads = hparams.get("num_attention_heads") + + if hidden_size is None or num_attention_heads is None: + raise ValueError("Missing 'hidden_size' or 'num_attention_heads' for RoPE scaling") + + dim = hidden_size // num_attention_heads + scaled_base = base * (alpha ** (dim / (dim - 2))) + + self.gguf_writer.add_rope_freq_base(scaled_base) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(1) + self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) + self.gguf_writer.add_context_length(256 * 1024) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "lm_head.weight": + if self.hparams.get("tie_word_embeddings", False): + logger.info("Skipping tied output layer 'lm_head.weight'") + return [] + + # Handle MoE gate bias (e_score_correction_bias) - map to exp_probs_b + if "e_score_correction_bias" in name: + # This is the expert selection bias - map to blk.N.exp_probs_b + # Format: model.layers.N.mlp.gate.e_score_correction_bias -> blk.N.exp_probs_b + layer_num = int(name.split(".")[2]) # Extract layer number + new_name = f"blk.{layer_num}.exp_probs_b" + return [(new_name, data_torch)] + + # Handle shared FFN (non-expert layers) - pass through directly + if name.find("mlp.down_proj") != -1 or name.find("mlp.gate_proj") != -1 or name.find("mlp.up_proj") != -1: + if name.find("mlp.experts") == -1: + # This is a shared FFN layer, not an expert - pass through + return [(self.map_tensor_name(name), data_torch)] + + if name.find("mlp.experts") != -1: + n_experts = self.hparams.get("num_experts") or self.hparams.get("n_routed_experts") + if n_experts is None: + raise ValueError("Missing 'num_experts' or 'n_routed_experts' in config") + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register("HunYuanMoEV1ForCausalLM") class HunYuanMoEModel(TextModel): model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6b4b6c5ab075d..926958db2eb0a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -427,6 +427,7 @@ class MODEL_ARCH(IntEnum): COGVLM = auto() MINIMAXM2 = auto() PANGU_EMBED = auto() + MEGREZ_MOE = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -795,6 +796,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.MINIMAXM2: "minimax-m2", MODEL_ARCH.COGVLM: "cogvlm", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", + MODEL_ARCH.MEGREZ_MOE: "megrez-moe", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1549,6 +1551,29 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.MEGREZ_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], MODEL_ARCH.QWEN3VL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 630b2cddf67e8..84df832d1222c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -89,6 +89,7 @@ add_library(llama models/mamba.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/megrez-moe.cpp models/mpt.cpp models/nemotron-h.cpp models/nemotron.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568dffb..508c82ef9a58a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -108,6 +108,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_COGVLM, "cogvlm" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, + { LLM_ARCH_MEGREZ_MOE, "megrez-moe" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -2378,6 +2379,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, + { + LLM_ARCH_MEGREZ_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_PANGU_EMBED, { diff --git a/src/llama-arch.h b/src/llama-arch.h index a769dd1e85741..28719e62ae8f0 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -112,6 +112,7 @@ enum llm_arch { LLM_ARCH_MINIMAX_M2, LLM_ARCH_COGVLM, LLM_ARCH_PANGU_EMBED, + LLM_ARCH_MEGREZ_MOE, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 70a3ec62dfc63..6e1cd4e4e1918 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1386,7 +1386,10 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { - return std::max(1024u, 8u*model.n_tensors()); + // Megrez-MoE creates many intermediate tensors (~35 per MoE layer) + // Use higher factor (9u) instead of base 8u to account for this overhead + uint32_t factor = (model.arch == LLM_ARCH_MEGREZ_MOE) ? 9u : 8u; + return std::max(1024u, factor * model.n_tensors()); } llm_graph_result * llama_context::get_gf_res_reserve() const { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 829f1e3c14f82..9fd4b066fa171 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2174,12 +2174,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_PANGU_EMBED: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MEGREZ_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer) { + case 31: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2631,9 +2645,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_gate_exps = + create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = + create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = + create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); } } break; case LLM_ARCH_LLAMA4: @@ -3319,9 +3336,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // MoE branch const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_gate_exps = + create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = + create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = + create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); // Shared expert branch const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; @@ -3332,6 +3352,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); } } break; + case LLM_ARCH_MEGREZ_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i == 0) { + // dense FFN + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE FFN + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), {n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for MEGREZ_MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for MEGREZ_MOE"); + } + + // Shared expert branch + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + // Routed expert branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), + { n_embd, n_ff_exp, n_expert }, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), + { n_ff_exp, n_embd, n_expert }, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), + { n_embd, n_ff_exp, n_expert }, TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3VL: { @@ -7165,6 +7241,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MEGREZ_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEMOTRON: { llm = std::make_unique(*this, params); @@ -7509,6 +7589,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GPTNEOX: case LLM_ARCH_CODESHELL: case LLM_ARCH_ORION: + case LLM_ARCH_MEGREZ_MOE: case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE4: diff --git a/src/models/megrez-moe.cpp b/src/models/megrez-moe.cpp new file mode 100644 index 0000000000000..854a685de0ed8 --- /dev/null +++ b/src/models/megrez-moe.cpp @@ -0,0 +1,156 @@ +#include "models.h" + +llm_build_megrez_moe::llm_build_megrez_moe(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * pre_gate_hidden = nullptr; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + // skip computing output for unused tokens + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + if (pre_gate_hidden) { + pre_gate_hidden = ggml_get_rows(ctx0, pre_gate_hidden, inp_out_ids); + } + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if (il == 0) { + // save for MoE gating + pre_gate_hidden = cur; + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE with pre-computed gate logits + ggml_tensor * gate_logits = build_lora_mm(model.layers[il].ffn_gate_inp, pre_gate_hidden); + cb(gate_logits, "ffn_moe_logits", il); + + // Expert sharing: layers 1,2,3 share experts from layer 1; 4,5,6 from layer 4, etc. + const int expert_layer_stride = 3; + const int expert_layer = ((il - 1) / expert_layer_stride) * expert_layer_stride + 1; + + // MoE FFN with pre-computed gate logits + ggml_tensor * moe_out = + build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[expert_layer].ffn_up_exps, + model.layers[expert_layer].ffn_gate_exps, model.layers[expert_layer].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, + 1.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il, gate_logits); + cb(moe_out, "ffn_moe_out", il); + + pre_gate_hidden = cur; + + // FFN shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 2fffb382df2e5..daf6707afc280 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -317,6 +317,10 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_megrez_moe : public llm_graph_context { + llm_build_megrez_moe(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); };