@@ -2708,8 +2708,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0);
         layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_ff, n_embd }, 0); // [3072, 384]
-        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, 2 * n_ff }, 0);
+        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, 2 * n_ff }, 0);
+        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd }, 0);
         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
     }
 } break;
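The corrected shapes follow ggml's convention that a weight tensor is stored as {n_in, n_out}, so ggml_mul_mat(W, x) with x of shape {n_in, n_tokens} yields {n_out, n_tokens}. A minimal sketch of the shapes this implies for the fused up+gate projection (names and comments are illustrative, not from the model config):

    // shape convention assumed by the change above
    // ffn_up   : {n_embd, 2*n_ff}  -> ggml_mul_mat(ffn_up,   x) : {2*n_ff, n_tokens}  (fused up+gate)
    // ffn_down : {n_ff,   n_embd}  -> ggml_mul_mat(ffn_down, h) : {n_embd, n_tokens}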
@@ -7548,6 +7548,7 @@ struct llm_build_modern_bert : public llm_graph_context {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa(); // == n_head_kv * n_embd_head
         const int64_t n_tokens    = ubatch.n_tokens;
+        const int64_t n_ff        = hparams.n_ff();
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -7667,30 +7668,63 @@ struct llm_build_modern_bert : public llm_graph_context {
 
         // MLP (prefer GEGLU if gate exists or up has 2*n_ff rows)
         ggml_tensor * mlp_out = nullptr;
-        const bool has_gate_tensor = (model.layers[il].ffn_gate != nullptr);
-        const bool up_is_2x        = (model.layers[il].ffn_up && model.layers[il].ffn_up->ne[0] == 2*hparams.n_ff());
+        ggml_tensor * ffn_gate_view = model.layers[il].ffn_gate;
+        ggml_tensor * ffn_up_view   = model.layers[il].ffn_up;
+
+        if (ffn_gate_view == nullptr && ffn_up_view) {
+
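+            // ModernBERT's Wi fuses the up and gate projections into a single tensor, so the GGUF
+            // carries no separate ffn_gate; split ffn_up into two views instead. Two storage
+            // orientations are handled below, depending on how the fused weight was written at
+            // conversion time.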
+            // Case A: weight stored as {2*n_ff, n_embd} -> split dim 0 into two {n_ff, n_embd} views
+            if (ffn_up_view->ne[0] == 2 * n_ff && ffn_up_view->ne[1] == n_embd) {
+                // first n_ff columns of each row -> up
+                ffn_up_view = ggml_view_2d(ctx0, model.layers[il].ffn_up,
+                        /*ne0*/ n_ff, /*ne1*/ n_embd,
+                        /*nb1*/ model.layers[il].ffn_up->nb[1],
+                        /*offset_bytes*/ 0);
+                // remaining n_ff columns -> gate (offset = n_ff elements along dim 0, in bytes)
+                ffn_gate_view = ggml_view_2d(ctx0, model.layers[il].ffn_up,
+                        /*ne0*/ n_ff, /*ne1*/ n_embd,
+                        /*nb1*/ model.layers[il].ffn_up->nb[1],
+                        /*offset_bytes*/ ggml_row_size(model.layers[il].ffn_up->type, n_ff));
+            }
+            // Case B: weight stored as {n_embd, 2*n_ff} -> split dim 1 into two {n_embd, n_ff} views
+            else if (ffn_up_view->ne[0] == n_embd && ffn_up_view->ne[1] == 2 * n_ff) {
+                // first n_ff rows -> up
+                ffn_up_view = ggml_view_2d(ctx0, model.layers[il].ffn_up,
+                        n_embd, n_ff,
+                        model.layers[il].ffn_up->nb[1],
+                        0);
+                ffn_up_view = ggml_cont(ctx0, ffn_up_view);
+
+                // next n_ff rows -> gate (offset = n_ff full rows, in bytes)
+                ffn_gate_view = ggml_view_2d(ctx0, model.layers[il].ffn_up,
+                        n_embd, n_ff,
+                        model.layers[il].ffn_up->nb[1],
+                        (size_t) n_ff * model.layers[il].ffn_up->nb[1]);
+                ffn_gate_view = ggml_cont(ctx0, ffn_gate_view);
+            }
+
+            ggml_tensor * ffn_down_view = model.layers[il].ffn_down;
+            LLAMA_LOG_INFO("ffn shapes: up: {%lld, %lld}, gate: {%lld, %lld}, down: {%lld, %lld}\n",
+                    ffn_up_view->ne[0], ffn_up_view->ne[1], ffn_gate_view->ne[0], ffn_gate_view->ne[1], ffn_down_view->ne[0], ffn_down_view->ne[1]);
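+            // After the split, up and gate should have identical shapes ({n_embd, n_ff} for Case B,
+            // {n_ff, n_embd} for Case A) and ffn_down should be {n_ff, n_embd}; the log above is a
+            // quick sanity check that one of the two expected layouts was matched.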
 
-            if (has_gate_tensor || up_is_2x) {
             mlp_out = build_ffn(
                     h,
-                    model.layers[il].ffn_up,   /*up_b*/ NULL, /*up_shexp*/ NULL,
-                    model.layers[il].ffn_gate, /*gate_b*/ NULL, /*gate_shexp*/ NULL,
+                    ffn_up_view,               /*up_b*/ NULL, /*up_shexp*/ NULL,
+                    ffn_gate_view,             /*gate_b*/ NULL, /*gate_shexp*/ NULL,
                     model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL,
                     /*expert_scores*/ NULL,
-                    LLM_FFN_GEGLU, LLM_FFN_PAR, il);
-            cb(mlp_out, "ffn_out_geglu", il);
+                    LLM_FFN_GEGLU, LLM_FFN_PAR, il
+            );
+            cb(mlp_out, "ffn_out_geglu", il);
         } else {
-
-            LLAMA_LOG_INFO("Ffn_up : {%lld, %lld}, ffn_down : {%lld, %lld}\n", model.layers[il].ffn_up->ne[0], model.layers[il].ffn_up->ne[1],
-                    model.layers[il].ffn_down->ne[0], model.layers[il].ffn_down->ne[0]);
             mlp_out = build_ffn(
                     h,
-                    model.layers[il].ffn_up,   /*up_b*/ NULL, /*up_shexp*/ NULL,
-                    /*gate*/ NULL,             /*gate_b*/ NULL, /*gate_shexp*/ NULL,
-                    model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL,
-                    /*expert_scores*/ NULL,
-                    LLM_FFN_GELU, LLM_FFN_SEQ, il);
-            cb(mlp_out, "ffn_out_gelu", il);
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_GEGLU, LLM_FFN_PAR, il
+            );
+            cb(mlp_out, "ffn_out_geglu", il);
         }
 
         // Residual after MLP
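For reference, a standalone sketch of the split performed in Case B above: given a fused projection stored as {n_embd, 2*n_ff}, two ggml_view_2d calls with the original row stride (nb[1]) and a byte offset of n_ff rows yield separate up and gate halves. All names, toy sizes, and the choice of which half is activated are illustrative assumptions, not taken from the model or the conversion script; the ggml calls themselves are standard API.

    #include "ggml.h"

    int main(void) {
        const int64_t n_embd = 8, n_ff = 16, n_tokens = 4;   // toy sizes for illustration

        ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
        ggml_context * ctx = ggml_init(params);

        // fused up+gate weight, laid out like ffn_up after this change: {n_embd, 2*n_ff}
        ggml_tensor * w_fused = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 2*n_ff);
        ggml_tensor * x       = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);

        // first n_ff rows -> up, next n_ff rows -> gate (offset = n_ff rows, in bytes)
        ggml_tensor * w_up   = ggml_view_2d(ctx, w_fused, n_embd, n_ff, w_fused->nb[1], 0);
        ggml_tensor * w_gate = ggml_view_2d(ctx, w_fused, n_embd, n_ff, w_fused->nb[1],
                                            (size_t) n_ff * w_fused->nb[1]);

        // gated MLP up-projection: gelu(gate(x)) * up(x); both projections are {n_ff, n_tokens}
        ggml_tensor * up   = ggml_mul_mat(ctx, w_up,   x);
        ggml_tensor * gate = ggml_mul_mat(ctx, w_gate, x);
        ggml_tensor * h    = ggml_mul(ctx, ggml_gelu(ctx, gate), up);

        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, h);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);   // data is uninitialized; only shapes are checked

        GGML_ASSERT(h->ne[0] == n_ff && h->ne[1] == n_tokens);

        ggml_free(ctx);
        return 0;
    }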