@@ -6319,57 +6319,56 @@ struct llm_build_neo_bert : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_no_cache();
 
-        ggml_tensor * inp_out_ids = build_inp_out_ids();
-
+        // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * cur = inpL;
 
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
             // pre-norm
             cur = build_norm(inpL,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, il);
 
-            {
-                ggml_tensor * Qcur;
-                ggml_tensor * Kcur;
-                ggml_tensor * Vcur;
-
-                // self-attention
-                cur = build_lora_mm(model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-                // RoPE
-                Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
+            // self-attention
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
+
+            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // RoPE
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
 
-                Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
-                        model.layers[il].wo, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-                cb(cur, "kqv_out", il);
-            }
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
 
-            if (il == n_layer - 1 && inp_out_ids) {
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
                 inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
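Note on the fused QKV split in this hunk: build_lora_mm applied to wqkv produces one row per token laid out as [ Q | K | V ] = [ n_embd | n_embd_gqa | n_embd_gqa ] floats, and the three ggml_view_2d calls carve each row up by byte offset. Below is a minimal standalone sketch of that offset arithmetic in plain C++ (not ggml; the sizes are illustrative assumptions, not NeoBERT's actual hyperparameters):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // example widths only; n_embd_gqa equals n_embd when n_head_kv == n_head
    const size_t n_embd     = 768;
    const size_t n_embd_gqa = 768;

    // byte offsets into one [Q|K|V] row of the wqkv output, matching the
    // ggml_view_2d offset arguments in the diff above
    const size_t off_q = 0 * sizeof(float) * n_embd;            // Q starts at 0
    const size_t off_k = 1 * sizeof(float) * n_embd;            // K starts right after Q
    const size_t off_v = sizeof(float) * (n_embd + n_embd_gqa); // V starts after Q and K

    printf("Q at %zu, K at %zu, V at %zu bytes into each row\n", off_q, off_k, off_v);
    return 0;
}
```

The row stride passed to each view is cur->nb[1], i.e. one full [Q|K|V] row, so each view steps token to token while reading only its own slice. The inp_out_ids gather at the last layer then keeps only the rows for tokens whose outputs were actually requested, so the rest of the graph skips the unused ones.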
@@ -14798,6 +14797,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
+        case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_ARCEE:
             return LLAMA_ROPE_TYPE_NORM;
 
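For context on this second hunk: returning LLAMA_ROPE_TYPE_NORM for LLM_ARCH_NEO_BERT selects the "normal" rotation pattern, which rotates adjacent dimension pairs (x[2i], x[2i+1]) together, as opposed to the NEOX style that pairs x[i] with x[i + n_dims/2]. A simplified sketch of that pairing (illustrative only; it omits the freq_scale/ext_factor/YaRN corrections that ggml_rope_ext applies):

```cpp
#include <cmath>
#include <cstdio>

// rotate one head of dimension n_dims at position pos, "normal" style:
// each adjacent pair (x[2i], x[2i+1]) is rotated by its own angle
void rope_norm(float * x, int n_dims, int pos, float freq_base) {
    for (int i = 0; i < n_dims; i += 2) {
        const float theta = pos * std::pow(freq_base, -(float) i / n_dims);
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x0 = x[i];
        const float x1 = x[i + 1];
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}

int main() {
    float head[8] = {1, 0, 1, 0, 1, 0, 1, 0};
    rope_norm(head, 8, /*pos=*/3, /*freq_base=*/10000.0f);
    for (float v : head) {
        printf("%.3f ", v);
    }
    printf("\n");
    return 0;
}
```

Picking the wrong type silently rotates the wrong dimension pairs, which is why each architecture must be listed explicitly in llama_model_rope_type.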