@@ -4763,10 +4763,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }

-                    GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0");
                     for (int i = 0; i < n_layer; ++i) {
-                        bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
-
                         auto & layer = layers[i];

                         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
@@ -4784,7 +4781,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

-                        if (is_moe_layer) {
+                        // Ernie 4.5 MoE has some dense layers, so we check for the existence of the gate tensor
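+                        // (get_tensor_meta returns nullptr when a tensor is missing from the GGUF, so dense layers simply skip the expert tensors)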
+                        if (ml.get_tensor_meta(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i).str().c_str())) {
                             int n_ff_exp = hparams.n_ff_exp;

                             layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
@@ -8346,158 +8344,171 @@ struct llm_build_phi2 : public llm_graph_context {
 };

 struct llm_build_ernie4_5_moe : public llm_graph_context {
-    llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
-        const int64_t n_embd_head = hparams.n_embd_head_v;
+    llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;

-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);

-        ggml_tensor * cur;
-        ggml_tensor * inpL;
+        ggml_tensor * cur;
+        ggml_tensor * inpL;

-        inpL = build_inp_embd(model.tok_embd);
+        inpL = build_inp_embd(model.tok_embd);

-        // inp_pos - contains the positions
-        ggml_tensor * inp_pos = build_inp_pos();
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();

-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified();

-        ggml_tensor * inp_out_ids = build_inp_out_ids();
+        ggml_tensor * inp_out_ids = build_inp_out_ids();

-        for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * inpSA = inpL;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
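+            // inpSA keeps the layer input for the residual connection around attention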

-            // norm
-            {
-                cur = build_norm(inpL,
-                        model.layers[il].attn_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "attn_norm", il);
-            }
+            // norm
+            {
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }

-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
-                }
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }

-                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
-                }
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }

-                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
-                }
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }

-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

-                Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                );
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );

-                Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                );
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );

-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);

-                cur = build_attn(inp_attn, gf,
-                        model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-                cb(cur, "attn_out", il);
-            }
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                cb(cur, "attn_out", il);
+            }

-            if (il == n_layer - 1 && inp_out_ids) {
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
+            if (il == n_layer - 1 && inp_out_ids) {
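+                // gather only the rows needed for the requested outputs on the final layer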
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }

-            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);

-            // MoE feed-forward network
-            {
-                cur = build_norm(ffn_inp,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "ffn_norm", il);
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
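+                // dense layer: no MoE gate tensor was loaded, so use a standard SwiGLU FFN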
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);

-                // MoE branch
-                ggml_tensor * moe_out = build_moe_ffn(cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, true,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        il);
-                cb(moe_out, "ffn_moe_out", il);
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);

-                // Shared expert (if present)
-                if (hparams.n_ff_shexp > 0) {
-                    ggml_tensor * ffn_shexp = build_ffn(cur,
-                            model.layers[il].ffn_up_shexp,   NULL, NULL,
-                            model.layers[il].ffn_gate_shexp, NULL, NULL,
-                            model.layers[il].ffn_down_shexp, NULL, NULL,
-                            NULL,
-                            LLM_FFN_SILU, LLM_FFN_PAR, il);
-                    cb(ffn_shexp, "ffn_shexp", il);
+                ggml_tensor * moe_out = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(moe_out, "ffn_moe_out", il);

-                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
-                } else {
-                    cur = moe_out;
+                // Shared expert (if present)
+                if (hparams.n_ff_shexp > 0) {
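+                    // the shared expert runs on every token; its output is added to the routed expert output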
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                } else {
+                    cur = moe_out;
+                }
+                cb(cur, "ffn_out", il);
             }
-                cb(cur, "ffn_out", il);
-            }

-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);

-            cur = build_cvec(cur, il);
-            cb(cur, "l_out", il);
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);

-            // input for next layer
-            inpL = cur;
-        }
+            // input for next layer
+            inpL = cur;
+        }

-        cur = inpL;
+        cur = inpL;

-        cur = build_norm(cur,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, -1);
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);

-        cb(cur, "result_norm", -1);
-        res->t_embd = cur;
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;

-        // lm_head
-        cur = build_lora_mm(model.output, cur);
+        // lm_head
+        cur = build_lora_mm(model.output, cur);

-        cb(cur, "result_output", -1);
-        res->t_logits = cur;
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;

-        ggml_build_forward_expand(gf, cur);
-    }
+        ggml_build_forward_expand(gf, cur);
+    }
 };

 template<bool iswa>