@@ -759,11 +759,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_MODERN_BERT:
             {
-                //ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
-                LLAMA_LOG_INFO("Switching Modern Bert Arch\n");
+                // ModernBERT defaults: local sliding-window attention, full attention every 3rd layer
+                hparams.swa_type = LLAMA_SWA_TYPE_LOCAL;
+
+                hparams.set_swa_pattern(3, 0);
+                hparams.n_swa = 128;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,         hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,             hparams.pooling_type, false);
+
                 switch (hparams.n_layer) {
                     case 12:
-                        type = LLM_TYPE_47M; break; // granite-embeddings-mall
+                        type = LLM_TYPE_47M; break; // granite-embeddings-small
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
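
For reference, a standalone sketch (not part of the patch) of how the defaults above map onto per-layer attention. It assumes the upstream `llama_hparams::set_swa_pattern(n_pattern, dense_first)` semantics, where `dense_first = false` gives full attention to every `n_pattern`-th layer (indices 2, 5, 8, ...) and the sliding window to the rest:

```cpp
// Illustrates set_swa_pattern(3, 0) with n_swa = 128 for a 12-layer model.
// Assumes the upstream rule swa_layers[il] = (il % n_pattern) < (n_pattern - 1).
#include <cstdio>

int main() {
    const int n_layer   = 12;  // granite-embeddings-small
    const int n_pattern = 3;
    const int n_swa     = 128;

    for (int il = 0; il < n_layer; ++il) {
        const bool is_swa = (il % n_pattern) < (n_pattern - 1);
        if (is_swa) {
            printf("layer %2d: sliding window (%d tokens)\n", il, n_swa);
        } else {
            printf("layer %2d: full attention\n", il);
        }
    }
    return 0;
}
```

Note that HF ModernBERT configs apply global attention where `layer_id % global_attn_every_n_layers == 0` (layers 0, 3, 6, ...), which corresponds to the `dense_first = true` variant; the pattern here is worth double-checking against the converted model.
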
@@ -7544,152 +7553,111 @@ struct llm_build_bert : public llm_graph_context {
 struct llm_build_modern_bert : public llm_graph_context {
     llm_build_modern_bert(const llama_model & model, const llm_graph_params & params)
         : llm_graph_context(params) {
-        const int64_t n_embd = hparams.n_embd;
-        const int64_t n_layer = hparams.n_layer;
-        const int64_t n_head = hparams.n_head();
-        const int64_t n_head_kv = hparams.n_head_kv();
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
-        const int64_t n_tokens = ubatch.n_tokens;
+        const int64_t n_embd      = hparams.n_embd;
+        const int64_t n_layer     = hparams.n_layer;
+        const int64_t n_head      = hparams.n_head();
+        const int64_t n_head_kv   = hparams.n_head_kv();
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        const int64_t n_tokens    = ubatch.n_tokens;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
-        // RoPE params
-        const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; // uses rotary
-        const int32_t n_rot = hparams.n_rot;
-        const int32_t n_ctx_orig = hparams.n_ctx_train;
-
-        ggml_tensor * cur;
-        ggml_tensor * inpL;
-        ggml_tensor * inp_pos = nullptr;
-
-        // needs positions for RoPE
-        inp_pos = build_inp_pos();
+        // RoPE params
+        const int32_t rope_type   = LLAMA_ROPE_TYPE_NEOX;
+        const int32_t n_rot       = hparams.n_rot;
+        const int32_t n_ctx_orig  = hparams.n_ctx_train;
+        const float   freq_base   = hparams.rope_freq_base_train;
+        const float   freq_scale  = hparams.rope_freq_scale_train;
+        const float   attn_factor = 1.0f;
+        const float   ext_factor  = 0.0f; // plain RoPE, no YaRN extension
+        const float   beta_fast   = 0.0f;
+        const float   beta_slow   = 0.0f;
 
-        // embeddings (token + optional type), NO absolute pos embed
-        inpL = build_inp_embd(model.tok_embd);
+        ggml_tensor * inp_pos = build_inp_pos(); // positions are needed for RoPE
+        ggml_tensor * inpL    = build_inp_embd(model.tok_embd);
 
         if (model.type_embd) {
-            ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
-            inpL = ggml_add(ctx0, inpL, type_row0);
+            inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0));
         }
-        cb(inpL, "inp_embd", -1);
-
-        // embeddings LayerNorm (embeddings.norm)
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
-        cb(inpL, "inp_norm", -1);
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        auto * inp_attn           = build_attn_inp_no_cache();
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * x = inpL;
 
-            // pre attention norm (attn_norm). Layer 0 may be Identity() -> nullptr
+            // pre-attention LayerNorm; layer 0 may be Identity() -> nullptr
             ggml_tensor * x_attn_in = x;
             if (model.layers[il].attn_norm) {
-                x_attn_in = build_norm(x,
-                        model.layers[il].attn_norm,
-                        model.layers[il].attn_norm_b,
-                        LLM_NORM, il);
-                cb(x_attn_in, "attn_pre_norm", il);
-            } else {
-                cb(x_attn_in, "attn_pre_norm_identity", il);
+                x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
             }
 
-            // Attention: fused Wqkv -> split -> heads -> RoPE(Q,K) -> attn -> Wo
-            ggml_tensor * qkv = nullptr;
-            ggml_tensor * Qcur;
-            ggml_tensor * Kcur;
-            ggml_tensor * Vcur;
-
-            GGML_ASSERT(model.layers[il].wqkv); // fused QKV
-            qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
-            cb(qkv, "wqkv", il);
-
+            // fused QKV projection
+            GGML_ASSERT(model.layers[il].wqkv);
+            ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
             if (model.layers[il].bqkv) {
                 qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv);
-                cb(qkv, "bqkv", il);
             }
 
-            // Fused layout: [ (n_embd + 2*n_embd_gqa), n_tokens ]
-            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0*sizeof(float)*(n_embd)));
-            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd)));
-            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+            ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd,     n_tokens, qkv->nb[1], 0));
+            ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd)));
+            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd + n_embd_gqa)));
 
-            // optional per Q/K
-            if (model.layers[il].attn_q_norm) {
-                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
-            }
-            if (model.layers[il].attn_k_norm) {
-                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
-            }
+            // optional Q/K LayerNorm
+            if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
+            if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
 
-            // heads
+            // reshape for multi-head attention
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            // RoPE (NEOX ... maybe?) on Q and K
+            // apply RoPE (NEOX) to Q and K
             Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow);
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
             Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow);
-
-            cb(Qcur, "Qcur_rope", il);
-            cb(Kcur, "Kcur_rope", il);
-            cb(Vcur, "Vcur", il);
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
 
             ggml_tensor * attn_out = build_attn(
                 inp_attn,
-                model.layers[il].wo, model.layers[il].bo, // Wo, optional bias
+                model.layers[il].wo, model.layers[il].bo,
                 Qcur, Kcur, Vcur,
-                /*K_cache */ nullptr,
-                /*V_cache */ nullptr,
+                /*kq_b */ nullptr,
+                /*v_mla*/ nullptr,
                 1.0f / sqrtf(float(n_embd_head)),
-                il);
-            cb(attn_out, "attn_out", il );
+                il
+            );
 
-            // residual after attention
             ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);
 
-            // ifwe subselect outputs, do it at the last layer after attn resid
+            // optionally subselect output tokens (inp_out_ids) at the last layer
             if (il == n_layer - 1 && inp_out_ids) {
-                cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
-                x = ggml_get_rows(ctx0, x, inp_out_ids);
+                cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
+                inpL     = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
 
-            // pre mlp norm
-            ggml_tensor * h = build_norm(cur_attn,
-                    model.layers[il].ffn_norm,
-                    model.layers[il].ffn_norm_b,
-                    LLM_NORM, il);
-            cb(h, "mlp_pre_norm", il);
+            // pre-MLP LayerNorm
+            ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);
 
-            // GEGLU because we will split ffn_up which has shape [n_embd, n_ff * 2] and ffn_down has shape [n_ff, n_embd]
+            // GEGLU FFN: ffn_up has shape [n_embd, 2*n_ff] and is split into up/gate
             ggml_tensor * mlp_out = build_ffn(
                 h,
-                model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL,
-                /*gate*/ NULL , /*gate_b*/ NULL, /*gate_shexp*/ NULL,
-                model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL,
-                /*act_scales*/ NULL,
+                model.layers[il].ffn_up,   NULL, NULL,
+                NULL,                      NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
                 LLM_FFN_GEGLU, LLM_FFN_PAR, il
             );
 
-            cb(mlp_out, "ffn_out_geglu", il);
-            // Residual after MLP
-            ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn);
-
-            // feed into next layer
-            inpL = cur_layer;
+            // residual after MLP; feeds the next layer
+            inpL = ggml_add(ctx0, mlp_out, cur_attn);
         }
 
-        // final model norm (final_norm)
-        cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
-        cb(cur, "final_norm", -1);
-
+        ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
         res->t_embd = cur;
         ggml_build_forward_expand(gf, cur);
     }
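
Two notes on the graph builder above. First, `ggml_view_2d` takes its final offset argument in bytes, which is why the fused-QKV slices scale the element offsets by `sizeof(float)`. A standalone sketch of that row layout, with hypothetical sizes:

```cpp
// Row layout of the fused QKV activation:
//   [ Q (n_embd) | K (n_embd_gqa) | V (n_embd_gqa) ]
// Byte offsets of each slice, assuming F32 storage.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t n_embd     = 768; // hypothetical model width
    const size_t n_embd_gqa = 768; // equals n_embd when n_head_kv == n_head

    const size_t q_off = 0;
    const size_t k_off = sizeof(float) *  n_embd;
    const size_t v_off = sizeof(float) * (n_embd + n_embd_gqa);
    const size_t nb1   = sizeof(float) * (n_embd + 2*n_embd_gqa); // row stride, cf. qkv->nb[1]

    printf("Q @ %zu, K @ %zu, V @ %zu, row stride %zu bytes\n", q_off, k_off, v_off, nb1);
    return 0;
}
```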
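
Second, the `LLM_FFN_GEGLU` path: `ffn_up` projects to `2*n_ff`, the result is split in half, GELU is applied to one half, and the product of the halves feeds `ffn_down`. A minimal sketch of that gating on one row vector; which half acts as the gate is an assumption here and depends on the checkpoint's weight ordering:

```cpp
// GEGLU gating on a single row of the doubled ffn_up output (size 2*n_ff).
// Assumption: the first half is activated, the second half is the value;
// swap the halves if the conversion uses the opposite ordering.
#include <cmath>
#include <cstddef>
#include <vector>

static float gelu(float x) {
    // tanh approximation of GELU
    return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

std::vector<float> geglu(const std::vector<float> & up_out) {
    const std::size_t n_ff = up_out.size() / 2;
    std::vector<float> h(n_ff);
    for (std::size_t i = 0; i < n_ff; ++i) {
        h[i] = gelu(up_out[i]) * up_out[n_ff + i]; // GELU(gate) * value
    }
    return h; // then ffn_down maps n_ff -> n_embd
}
```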