Commit c00b183

tamarPal authored and committed
feat: adapt Megrez-MoE to new models/*.cpp architecture
- Move llm_build_megrez_moe from llama-model.cpp to src/models/megrez-moe.cpp
- Add declaration to src/models/models.h
- Update CMakeLists.txt to include megrez-moe.cpp in build
- Resolve merge conflicts in llama-arch.cpp and llama-model.cpp
- Fix PANGU_EMBED case statement closing braces

The model loads successfully, all tests pass (40/40), and inference works correctly.
1 parent d7443ba commit c00b183
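Not shown in this diff, but relevant to the models/*.cpp layout the commit adapts to: the new builder still has to be selected when the compute graph is constructed. A minimal sketch of that dispatch, assuming the existing architecture switch in llama_model::build_graph and the std::make_unique pattern used by the other llm_build_* structs (sketch only, not part of this commit):

    // Sketch: fragment of the existing architecture switch in
    // llama_model::build_graph (assumed wiring, not included in this diff).
    case LLM_ARCH_MEGREZ_MOE:
        {
            llm = std::make_unique<llm_build_megrez_moe>(*this, params);
        } break;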

File tree

5 files changed, +225 -1 lines changed


src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ add_library(llama
    models/mamba.cpp
    models/minicpm3.cpp
    models/minimax-m2.cpp
+   models/megrez-moe.cpp
    models/mpt.cpp
    models/nemotron-h.cpp
    models/nemotron.cpp

src/llama-arch.cpp

Lines changed: 0 additions & 1 deletion
@@ -2402,7 +2402,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES
        { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
        { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
        { LLM_TENSOR_FFN_UP,   "blk.%d.ffn_up" },
-       >>>>>>> 256414a18 (feat: Add Megrez-MoE architecture support)
    },
},
{

src/llama-model.cpp

Lines changed: 4 additions & 0 deletions
@@ -2180,9 +2180,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
        case LLM_ARCH_PANGU_EMBED:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
                switch (hparams.n_layer) {
                    case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1
                    case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1
+                   default: type = LLM_TYPE_UNKNOWN;
+               }
+           } break;
        case LLM_ARCH_MEGREZ_MOE:
            {
                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);

src/models/megrez-moe.cpp

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@ (new file)
#include "models.h"

llm_build_megrez_moe::llm_build_megrez_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_v;

    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
    GGML_ASSERT(n_embd_head == hparams.n_rot);

    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    auto * inp_attn = build_attn_inp_kv();

    const float kq_scale = 1.0f/sqrtf(float(n_embd_head));

    ggml_tensor * pre_gate_hidden;
    // Layer 0
    {
        ggml_tensor * inpSA = inpL;

        // norm
        cur = build_norm(inpL,
                model.layers[0].attn_norm, NULL,
                LLM_NORM_RMS, 0);
        cb(cur, "attn_norm", 0);

        // compute Q and K and RoPE them
        ggml_tensor * Qcur = build_lora_mm(model.layers[0].wq, cur);
        cb(Qcur, "Qcur", 0);

        ggml_tensor * Kcur = build_lora_mm(model.layers[0].wk, cur);
        cb(Kcur, "Kcur", 0);

        ggml_tensor * Vcur = build_lora_mm(model.layers[0].wv, cur);
        cb(Vcur, "Vcur", 0);

        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

        Qcur = ggml_rope_ext(
                ctx0, Qcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
                );

        Kcur = ggml_rope_ext(
                ctx0, Kcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
                );

        cb(Qcur, "Qcur", 0);
        cb(Kcur, "Kcur", 0);
        cb(Vcur, "Vcur", 0);

        cur = build_attn(inp_attn,
                model.layers[0].wo, NULL,
                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, 0);

        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "ffn_inp", 0);

        // feed-forward network
        cur = build_norm(ffn_inp,
                model.layers[0].ffn_norm, NULL,
                LLM_NORM_RMS, 0);
        cb(cur, "ffn_norm", 0);

        pre_gate_hidden = cur;

        cur = build_ffn(cur,
                model.layers[0].ffn_up,   NULL, NULL,
                model.layers[0].ffn_gate, NULL, NULL,
                model.layers[0].ffn_down, NULL, NULL,
                NULL,
                LLM_FFN_SILU, LLM_FFN_PAR, 0);

        cb(cur, "ffn_out", 0);

        cur = ggml_add(ctx0, cur, ffn_inp);
        cb(cur, "ffn_out_add", 0);
    }
    inpL = cur;
    for (int il = 1; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        // norm
        cur = build_norm(cur,
                model.layers[il].attn_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // self-attention
        {
            // compute Q and K and RoPE them
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            Qcur = ggml_rope_ext(
                    ctx0, Qcur, inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                    );

            Kcur = ggml_rope_ext(
                    ctx0, Kcur, inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                    );

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);

            cur = build_attn(inp_attn,
                    model.layers[il].wo, NULL,
                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
        }

        if (il == n_layer - 1) {
            // skip computing output for unused tokens
            ggml_tensor * inp_out_ids = build_inp_out_ids();
            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            pre_gate_hidden = ggml_get_rows(ctx0, pre_gate_hidden, inp_out_ids);
        }

        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "ffn_inp", il);

        cur = build_norm(ffn_inp,
                model.layers[il].ffn_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        if ((uint32_t) il < hparams.n_layer_dense_lead) {
            cur = build_ffn(cur,
                    model.layers[il].ffn_up,   NULL, NULL,
                    model.layers[il].ffn_gate, NULL, NULL,
                    model.layers[il].ffn_down, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(cur, "ffn_out", il);
        } else {
            // MoE branch
            ggml_tensor * moe_out = build_mergez_moe_ffn(cur,
                    pre_gate_hidden,
                    model.layers[il].ffn_gate_inp, model.layers[il].ffn_exp_probs_b,
                    model.layers[((il - 1) / (3) * (3)) + 1].ffn_up_exps,
                    model.layers[((il - 1) / (3) * (3)) + 1].ffn_gate_exps,
                    model.layers[((il - 1) / (3) * (3)) + 1].ffn_down_exps,
                    n_expert, n_expert_used,
                    il);
            cb(moe_out, "ffn_moe_out", il);

            pre_gate_hidden = cur;

            // FFN shared expert
            {
                ggml_tensor * ffn_shexp = build_ffn(cur,
                        model.layers[il].ffn_up_shexp,   NULL, NULL,
                        model.layers[il].ffn_gate_shexp, NULL, NULL,
                        model.layers[il].ffn_down_shexp, NULL, NULL,
                        NULL,
                        LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(ffn_shexp, "ffn_shexp", il);

                cur = ggml_add(ctx0, moe_out, ffn_shexp);
                cb(cur, "ffn_out", il);
            }
        }

        cur = ggml_add(ctx0, cur, ffn_inp);

        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;

    cur = build_norm(cur,
            model.output_norm, NULL,
            LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
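The expert-tensor index ((il - 1) / (3) * (3)) + 1 used in the MoE branch maps every block of three layers onto the first layer of its group, so the layers that take the MoE branch reuse one set of ffn_*_exps tensors per group. A minimal standalone check of that arithmetic (illustrative only, not part of the commit):

    #include <cstdio>

    int main() {
        // Same integer arithmetic as the expert-tensor index in the MoE branch:
        // layers 1..3 -> 1, layers 4..6 -> 4, layers 7..9 -> 7, ...
        for (int il = 1; il < 10; ++il) {
            printf("layer %2d uses the expert tensors of layer %d\n", il, ((il - 1) / 3) * 3 + 1);
        }
        return 0;
    }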

src/models/models.h

Lines changed: 4 additions & 0 deletions
@@ -317,6 +317,10 @@ struct llm_build_minimax_m2 : public llm_graph_context {
    llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params);
};

+struct llm_build_megrez_moe : public llm_graph_context {
+   llm_build_megrez_moe(const llama_model & model, const llm_graph_params & params);
+};
+
struct llm_build_mpt : public llm_graph_context {
    llm_build_mpt(const llama_model & model, const llm_graph_params & params);
};
