@@ -7559,6 +7559,7 @@ struct llm_build_modern_bert : public llm_graph_context {
75597559 const int64_t n_head_kv = hparams.n_head_kv();
75607560 const int64_t n_embd_head = hparams.n_embd_head_v;
75617561 const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
7562+ const int64_t n_local_swa = hparams.n_swa;
75627563 const int64_t n_tokens = ubatch.n_tokens;
75637564
75647565 GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7574,27 +7575,38 @@ struct llm_build_modern_bert : public llm_graph_context {
75747575 const float beta_fast = 0.0f;
75757576 const float beta_slow = 0.0f;
75767577
7577- ggml_tensor * inp_pos = build_inp_pos();
7578+
7579+ ggml_tensor *inp_pos_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 4096, 1);
7580+ ggml_set_input(inp_pos_global);
7581+ size_t element_size = ggml_type_size(inp_pos_global->type);
7582+
7583+ size_t nb1 = element_size;
7584+ size_t nb2 = nb1;
7585+
7586+ inp_pos_global = ggml_view_3d(ctx0, inp_pos_global, 1, 1, 4096, nb1, nb2, 0);
7587+ inp_pos_global = ggml_cont(ctx0, inp_pos_global);
7588+
75787589 ggml_tensor * inpL = build_inp_embd(model.tok_embd);
75797590
75807591 if (model.type_embd) {
75817592 inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0));
75827593 }
75837594 inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
75847595
7585- auto * inp_attn = build_attn_inp_no_cache ();
7596+ auto * inp_attn = build_attn_inp_kv_unified_iswa ();
75867597 ggml_tensor * inp_out_ids = build_inp_out_ids();
75877598
7599+
75887600 for (int il = 0; il < n_layer; ++il) {
75897601 ggml_tensor * x = inpL;
75907602
7591- // Pre attention Layer norm
7603+ // pre attn LayerNorm
75927604 ggml_tensor * x_attn_in = x;
75937605 if (model.layers[il].attn_norm) {
75947606 x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
75957607 }
75967608
7597- // fused qkv
7609+ // fused QKV
75987610 GGML_ASSERT(model.layers[il].wqkv);
75997611 ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
76007612 if (model.layers[il].bqkv) {
@@ -7609,41 +7621,120 @@ struct llm_build_modern_bert : public llm_graph_context {
76097621 if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
76107622 if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
76117623
7612- // reshape for multi head
7624+        // reshape for multi-head attention
76137625 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
76147626 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
76157627 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
76167628
7617- // rope embedding
7618- Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
7619- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
7620- ext_factor, attn_factor, beta_fast, beta_slow);
7621- Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
7622- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
7629+ // global or local layer
7630+ bool is_global = ((il + 1) % 3 == 0);
7631+ float freq_base_l = is_global ? 160000.0f : 10000.0f; // rope theta
7632+ float freq_scale_l = 1.0f;
7633+
7634+ ggml_tensor * pos_q = inp_pos_global;
7635+
7636+ ggml_tensor * K_work = Kcur;
7637+ ggml_tensor * V_work = Vcur;
7638+ ggml_tensor * pos_k = inp_pos_global;
7639+
7640+ if (!is_global) {
7641+ ggml_tensor * idx_src = inp_attn->self_k_idxs_swa;
7642+
7643+ ggml_tensor * idx_view1d = ggml_view_1d(ctx0, idx_src, idx_src->ne[0], 0);
7644+ ggml_tensor * idx_cont = ggml_cont(ctx0, idx_view1d);
7645+
7646+ ggml_tensor * idx_i32 = idx_cont;
7647+ if (idx_i32->type != GGML_TYPE_I32) {
7648+ idx_i32 = ggml_cast(ctx0, idx_cont, GGML_TYPE_I32);
7649+ }
7650+
7651+ const int64_t n_indices = idx_i32->ne[0];
7652+ ggml_tensor * idx_2d = ggml_view_2d(ctx0, idx_i32, 1, n_indices, sizeof(int32_t), 0);
7653+
7654+ idx_2d = ggml_cont(ctx0, idx_2d);
7655+ if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32);
7656+
7657+            LLAMA_LOG_INFO("Kcur ne=[%lld,%lld,%lld], idx_2d ne=[%lld,%lld,%lld,%lld], type=%d\n",
7658+                Kcur->ne[0], Kcur->ne[1], Kcur->ne[2],
7659+                idx_2d->ne[0], idx_2d->ne[1], idx_2d->ne[2], idx_2d->ne[3], idx_2d->type);
7660+
7661+ K_work = ggml_get_rows(ctx0, Kcur, idx_2d);
7662+ V_work = ggml_get_rows(ctx0, Vcur, idx_2d);
7663+
7664+
7665+
7666+ ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d);
7667+
7668+ if (!ggml_is_vector(pos_rows)) {
7669+ const int64_t n_el = ggml_nelements(pos_rows);
7670+ pos_rows = ggml_view_1d(ctx0, pos_rows, n_el, 0);
7671+ pos_rows = ggml_cont(ctx0, pos_rows);
7672+ } else {
7673+ pos_rows = ggml_cont(ctx0, pos_rows);
7674+ }
7675+ // ensure I32
7676+ if (pos_rows->type != GGML_TYPE_I32) {
7677+ pos_rows = ggml_cast(ctx0, pos_rows, GGML_TYPE_I32);
7678+ }
7679+
7680+ // final pos_k to pass to rope
7681+ pos_k = pos_rows;
7682+ LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, type=%d\n", pos_k->ne[0], pos_k->type);
7683+ }
7684+
7685+ if( !ggml_is_vector(pos_q) ) {
7686+ const int64_t n_el = ggml_nelements(pos_q);
7687+ pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0);
7688+ pos_q = ggml_cont(ctx0, pos_q);
7689+ }
7690+            // pos_q was flattened to 1-D above, so it must be a vector here
7691+            GGML_ASSERT(ggml_is_vector(pos_q));
7692+
7693+
7694+ // apply rope
7695+ Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr,
7696+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
7697+ ext_factor, attn_factor, beta_fast, beta_slow);
7698+
7699+ if( !ggml_is_vector(pos_k) ) {
7700+ const int64_t n_el = ggml_nelements(pos_k);
7701+ pos_k = ggml_view_1d(ctx0, pos_k, n_el, 0);
7702+ pos_k = ggml_cont(ctx0, pos_k);
7703+ }
7704+
7705+ K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr,
7706+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
76237707 ext_factor, attn_factor, beta_fast, beta_slow);
76247708
7709+            // choosing mask: global vs SWA
7710+ ggml_tensor * kq_b_layer = is_global ? inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa;
7711+
76257712 ggml_tensor * attn_out = build_attn(
76267713 inp_attn,
7627- model.layers[il].wo, model.layers[il].bo,
7628- Qcur, Kcur, Vcur,
7629- /*k cache*/ nullptr,
7630- /*v cache*/ nullptr,
7714+ model.layers[il].wo,
7715+ model.layers[il].bo,
7716+ Qcur,
7717+ K_work,
7718+ V_work,
7719+ kq_b_layer,
7720+ nullptr,
76317721 1.0f / sqrtf(float(n_embd_head)),
76327722 il
76337723 );
76347724
7725+ // residual addition
76357726 ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);
76367727
7637- // optional subselect output tokens (inp_out_ids)
7728+ // optional output select
76387729 if (il == n_layer - 1 && inp_out_ids) {
76397730 cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
76407731 inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
76417732 }
76427733
7643- // pre mlp LayerNorm
7734+ // pre mlp layer norm
76447735 ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);
76457736
7646- // geglu FFN
7737+ // geglu ffn
76477738 ggml_tensor * mlp_out = build_ffn(
76487739 h,
76497740 model.layers[il].ffn_up, NULL, NULL,
@@ -7653,10 +7744,11 @@ struct llm_build_modern_bert : public llm_graph_context {
76537744 LLM_FFN_GEGLU, LLM_FFN_PAR, il
76547745 );
76557746
7656- // resid addition
7747+        // residual addition after FFN
76577748 inpL = ggml_add(ctx0, mlp_out, cur_attn);
76587749 }
76597750
7751+
76607752 ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
76617753 res->t_embd = cur;
76627754 ggml_build_forward_expand(gf, cur);
0 commit comments