Skip to content

Commit 3e18442

Browse files
committed
llama-model : implement GLM4 MoE inference graph
1 parent d048901 commit 3e18442

File tree

1 file changed

+136
-1
lines changed

1 file changed

+136
-1
lines changed

src/llama-model.cpp

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13651,7 +13651,142 @@ struct llm_build_glm4 : public llm_graph_context {
1365113651

1365213652
struct llm_build_glm4_moe : public llm_graph_context {
1365313653
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
13654-
// TODO
13654+
const int64_t n_embd_head = hparams.n_embd_head_v;
13655+
const int64_t n_rot = hparams.n_rot;
13656+
13657+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
13658+
GGML_ASSERT(n_rot == n_embd_head / 2);
13659+
13660+
ggml_tensor * cur;
13661+
ggml_tensor * inpL;
13662+
13663+
inpL = build_inp_embd(model.tok_embd);
13664+
13665+
ggml_tensor * inp_pos = build_inp_pos();
13666+
auto * inp_attn = build_attn_inp_kv_unified();
13667+
ggml_tensor * inp_out_ids = build_inp_out_ids();
13668+
13669+
for (int il = 0; il < n_layer; ++il) {
13670+
ggml_tensor * inpSA = inpL;
13671+
13672+
// pre-attention norm
13673+
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
13674+
cb(cur, "attn_norm", il);
13675+
13676+
// self-attention block
13677+
{
13678+
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
13679+
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
13680+
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
13681+
cb(Qcur, "Qcur", il);
13682+
cb(Kcur, "Kcur", il);
13683+
cb(Vcur, "Vcur", il);
13684+
13685+
// optional QK norm
13686+
if (hparams.use_kq_norm) {
13687+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
13688+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
13689+
13690+
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
13691+
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
13692+
cb(Qcur, "Qcur_normed", il);
13693+
cb(Kcur, "Kcur_normed", il);
13694+
}
13695+
13696+
// reshape QKV
13697+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
13698+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
13699+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
13700+
13701+
// apply RoPE
13702+
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
13703+
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
13704+
cb(Qcur, "Qcur_roped", il);
13705+
cb(Kcur, "Kcur_roped", il);
13706+
13707+
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
13708+
13709+
cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
13710+
cb(cur, "attn_out", il);
13711+
}
13712+
13713+
// first residual
13714+
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
13715+
cb(ffn_inp, "ffn_inp", il);
13716+
13717+
// pre-ffn RMSnorm
13718+
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
13719+
cb(cur, "ffn_norm", il);
13720+
13721+
//
13722+
if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
13723+
// dense FFN
13724+
cur = build_ffn(cur,
13725+
model.layers[il].ffn_up, NULL, NULL,
13726+
model.layers[il].ffn_gate, NULL, NULL,
13727+
model.layers[il].ffn_down, NULL, NULL,
13728+
NULL,
13729+
LLM_FFN_SILU, LLM_FFN_PAR, il);
13730+
cb(cur, "ffn_dense_out", il);
13731+
} else {
13732+
// shared expert
13733+
ggml_tensor * shexp_out = build_ffn(cur,
13734+
model.layers[il].ffn_up_shexp, NULL, NULL,
13735+
model.layers[il].ffn_gate_shexp, NULL, NULL,
13736+
model.layers[il].ffn_down_shexp, NULL, NULL,
13737+
NULL,
13738+
LLM_FFN_SILU, LLM_FFN_PAR, il);
13739+
cb(shexp_out, "ffn_shexp_out", il);
13740+
13741+
// conditional experts
13742+
ggml_tensor * moe_out = build_moe_ffn(cur,
13743+
model.layers[il].ffn_gate_inp,
13744+
model.layers[il].ffn_up_exps,
13745+
model.layers[il].ffn_gate_exps,
13746+
model.layers[il].ffn_down_exps,
13747+
model.layers[il].ffn_exp_probs_b,
13748+
n_expert, n_expert_used,
13749+
LLM_FFN_SILU,
13750+
true, // norm_topk_prob
13751+
true, // use expert bias
13752+
hparams.expert_weights_scale,
13753+
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, // IMPORTANT -- MUST USE SIGMOID
13754+
il);
13755+
cb(moe_out, "ffn_moe_out", il);
13756+
13757+
// combine output from shared and routed experts
13758+
cur = ggml_add(ctx0, moe_out, shexp_out);
13759+
cb(cur, "ffn_moe_combined", il);
13760+
}
13761+
13762+
if (il == n_layer - 1 && inp_out_ids) {
13763+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13764+
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
13765+
}
13766+
13767+
// second residual
13768+
cur = ggml_add(ctx0, cur, ffn_inp);
13769+
13770+
cur = build_cvec(cur, il);
13771+
cb(cur, "l_out", il);
13772+
13773+
// input for next layer
13774+
inpL = cur;
13775+
}
13776+
13777+
cur = inpL;
13778+
13779+
// output norm
13780+
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
13781+
cb(cur, "output_norm", -1);
13782+
res->t_embd = cur;
13783+
13784+
// final output
13785+
cur = build_lora_mm(model.output, cur);
13786+
cb(cur, "output", -1);
13787+
res->t_logits = cur;
13788+
13789+
ggml_build_forward_expand(gf, cur);
1365513790
}
1365613791
};
1365713792

0 commit comments

Comments
 (0)