@@ -15677,28 +15677,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
                 ext_factor, attn_factor, beta_fast, beta_slow
             );
 
-            // Store original K and V for KV cache (before GQA expansion)
-            ggml_tensor * Kcur_cache = Kcur;
-            ggml_tensor * Vcur_cache = Vcur;
-
-            // PLaMo-2 GQA: expand K and V heads to match Q heads (equivalent to _expand_kv)
-            if (n_head_kv < n_head) {
-                // const int n_group = n_head / n_head_kv;
-
-                // manually expand K and V tensors to repeat each head n_group times
-                // create expanded tensors with target dimensions
-                ggml_tensor * Kcur_expanded = ggml_new_tensor_3d(ctx0, Kcur->type, n_embd_head_k, n_head, n_tokens);
-                ggml_tensor * Vcur_expanded = ggml_new_tensor_3d(ctx0, Vcur->type, n_embd_head_v, n_head, n_tokens);
-
-                // repeat each head n_group times
-                Kcur = ggml_repeat(ctx0, Kcur, Kcur_expanded);
-                Vcur = ggml_repeat(ctx0, Vcur, Vcur_expanded);
-
-                cb(Kcur, "Kcur_expanded", il);
-                cb(Vcur, "Vcur_expanded", il);
-            }
-
-            cur = build_attn(inp, gf, model.layers[il].wo, NULL, Qcur, Kcur_cache, Vcur_cache, NULL, NULL, 1.0f, il);
+            cur = build_attn(inp, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il);
         }
 
         cb(cur, "attn_out", il);
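The hunk above drops PLaMo-2's manual GQA expansion and passes the unexpanded Kcur/Vcur straight to build_attn, relying on the attention path to broadcast the n_head_kv key/value heads across the grouped query heads. For readers unfamiliar with that expansion, below is a minimal, dependency-free C++ sketch of the idea; expand_kv, its flat [head_dim, n_head, n_tokens] layout, and the contiguous-group mapping (query head h reads KV head h / n_group) are illustrative assumptions, not llama.cpp API.

// Illustrative only: grouped-query KV expansion on plain arrays.
// The removed block performed a similar head expansion on ggml tensors via ggml_repeat.
#include <cassert>
#include <cstdio>
#include <vector>

// Duplicate each of the n_head_kv heads so that a contiguous group of
// n_head / n_head_kv query heads shares one key/value head.
static std::vector<float> expand_kv(const std::vector<float> & src,
                                    int head_dim, int n_head_kv, int n_head, int n_tokens) {
    assert(n_head % n_head_kv == 0);
    const int n_group = n_head / n_head_kv;    // query heads per KV head
    std::vector<float> dst((size_t) head_dim * n_head * n_tokens);
    for (int t = 0; t < n_tokens; ++t) {
        for (int h = 0; h < n_head; ++h) {
            const int h_kv = h / n_group;      // source KV head for query head h
            for (int d = 0; d < head_dim; ++d) {
                dst[((size_t) t * n_head    + h)    * head_dim + d] =
                src[((size_t) t * n_head_kv + h_kv) * head_dim + d];
            }
        }
    }
    return dst;
}

int main() {
    // head_dim = 2, n_head_kv = 2, n_head = 4, n_tokens = 1
    const std::vector<float> k = { 1, 2,  3, 4 };              // KV head 0, KV head 1
    for (float v : expand_kv(k, 2, 2, 4, 1)) printf("%g ", v); // prints: 1 2 1 2 3 4 3 4
    printf("\n");
    return 0;
}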