@@ -2183,6 +2183,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 switch (hparams.n_layer) {
                     case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1
                     case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MEGREZ_MOE:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,               hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,                hparams.expert_gating_func, false);
+
+                switch (hparams.n_layer) {
+                    case 31: type = LLM_TYPE_7B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
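Note: LLM_KV_EXPERT_GATING_FUNC is read with required = false, so a converter that omits the key leaves hparams.expert_gating_func at its default. A minimal sketch of the fallback that other MoE paths in load_hparams (e.g. DeepSeek-V2) apply in that situation; this snippet is not part of the commit and is only an assumption about what the architecture would need:

    if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
        // assumption: default to softmax gating when the GGUF does not set expert_gating_func,
        // mirroring the existing DeepSeek-V2 handling in load_hparams
        hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
    }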
@@ -3338,6 +3348,65 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_shexp}, 0);
                 }
             } break;
+        case LLM_ARCH_MEGREZ_MOE:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd},     0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},     0);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                    // Layer 0 is dense, layers 1-30 are MoE
+                    if (i == 0) {
+                        // Dense layer
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                    } else {
+                        // All MoE layers (1-30) have these
+                        layer.ffn_gate_inp    = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), {n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0 for MEGREZ_MOE");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0 for MEGREZ_MOE");
+                        }
+
+                        // All MoE layers have a shared expert
+                        const int64_t n_ff_shexp = hparams.n_ff_shexp;
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
+                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_shexp}, 0);
+
+                        // Only layers 1, 4, 7, 10, 13, 16, 19, 22, 25, 28 have actual expert tensors
+                        // Pattern: (i - 1) % 3 == 0 for i > 0
+                        if ((i - 1) % 3 == 0) {
+                            // MoE branch - use the expert-specific FF size from hparams
+                            const int64_t n_ff_exp = hparams.n_ff_exp;
+
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+                        }
+                        // Note: layers that share experts (2, 3, 5, 6, etc.) only have gate_inp and the shared expert.
+                        // They will reference the regular experts from their corresponding "full" layer during inference.
+                    }
+                }
+            } break;
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3VL:
             {
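The tensor-loading comments above imply a fixed mapping from every MoE layer to the layer that actually owns the expert weights: only layers 1, 4, 7, ..., 28 create ffn_*_exps, and the two layers after each of them reuse those tensors. A minimal sketch of that mapping as I read the pattern (the helper name is hypothetical and does not appear in the commit):

    // For i > 0: layers 1, 2, 3 -> 1; layers 4, 5, 6 -> 4; ...; layers 28, 29, 30 -> 28.
    static int megrez_moe_expert_layer(int il) {
        // assumption: layer 0 is dense, and every group of three MoE layers shares
        // the expert tensors stored on the first layer of the group
        return il <= 0 ? il : ((il - 1) / 3) * 3 + 1;
    }

During graph build, a layer without its own ffn_*_exps tensors would then read layers[megrez_moe_expert_layer(il)].ffn_gate_exps / ffn_down_exps / ffn_up_exps while still using its own router (ffn_gate_inp) and shared expert.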
@@ -7178,6 +7247,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_jais>(*this, params);
             } break;
+        case LLM_ARCH_MEGREZ_MOE:
+            {
+                llm = std::make_unique<llm_build_megrez_moe>(*this, params);
+            } break;
         case LLM_ARCH_NEMOTRON:
             {
                 llm = std::make_unique<llm_build_nemotron>(*this, params);
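The dispatch above implies a builder with the same constructor shape as the other llm_build_* graph builders; its body is not shown in this commit. A skeleton consistent with the call site, assuming the usual llm_graph_context base class:

    struct llm_build_megrez_moe : public llm_graph_context {
        // signature implied by std::make_unique<llm_build_megrez_moe>(*this, params)
        llm_build_megrez_moe(const llama_model & model, const llm_graph_params & params);
    };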
@@ -7518,6 +7591,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GPTNEOX:
         case LLM_ARCH_CODESHELL:
         case LLM_ARCH_ORION:
+        case LLM_ARCH_MEGREZ_MOE:
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
         case LLM_ARCH_EXAONE4: