From bb1202b281c67a71158202d4ac47b6cc87fe4470 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 Aug 2025 13:31:02 +0300 Subject: [PATCH 1/7] model : avoid ggml_cont_3d for fused QKV weights ggml-ci --- src/llama-kv-cache.cpp | 20 ++++++- src/llama-model.cpp | 128 ++++++++++++----------------------------- 2 files changed, 54 insertions(+), 94 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index ae35f74201e..c3e9bc74a89 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1020,9 +1020,14 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm auto * k = layers[ikv].k; + const int64_t n_embd_k_gqa = k_cur->ne[0]*k_cur->ne[1]; const int64_t n_tokens = k_cur->ne[2]; - k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); + // we can merge dims 0 and 1 + assert(k_cur->nb[0]*k_cur->ne[0] == k_cur->nb[1]); + + //k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); + k_cur = ggml_view_2d(ctx, k_cur, n_embd_k_gqa, n_tokens, k_cur->nb[2], 0); if (k->ne[2] > 1) { k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); @@ -1041,9 +1046,14 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; const int64_t n_tokens = v_cur->ne[2]; - v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); + // we can merge dims 0 and 1 + assert(v_cur->nb[0]*v_cur->ne[0] == v_cur->nb[1]); + + //v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); if (!v_trans) { + v_cur = ggml_view_2d(ctx, v_cur, n_embd_v_gqa, n_tokens, v_cur->nb[2], 0); + if (v->ne[2] > 1) { v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); } @@ -1051,6 +1061,12 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm return ggml_set_rows(ctx, v, v_cur, v_idxs); } + if (v_cur->nb[1]*v_cur->ne[1] != v_cur->nb[2]) { + v_cur = ggml_cont_2d (ctx, v_cur, n_embd_v_gqa, n_tokens); + } else { + v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); + } + // [TAG_V_CACHE_VARIABLE] if (n_embd_v_gqa < v->ne[0]) { v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1813f06d7b3..b9e4634a706 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6927,9 +6927,7 @@ struct llm_build_falcon : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -7207,9 +7205,7 @@ struct llm_build_dbrx : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, 
n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -7329,13 +7325,9 @@ struct llm_build_starcoder : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7551,14 +7543,16 @@ struct llm_build_bert : public llm_graph_context { cb(cur, "bqkv", il); } - Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } @@ -7569,8 +7563,6 @@ struct llm_build_bert : public llm_graph_context { LLM_NORM, il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - } else { - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); } if (model.layers[il].attn_k_norm) { @@ -7580,8 +7572,6 @@ struct llm_build_bert : public llm_graph_context { LLM_NORM, il); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - } else { - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } // RoPE @@ -7727,9 +7717,7 @@ struct llm_build_neo_bert : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, 
n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // RoPE Qcur = ggml_rope_ext( @@ -7836,13 +7824,9 @@ struct llm_build_bloom : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7958,13 +7942,9 @@ struct llm_build_mpt : public llm_graph_context { cb(cur, "wqkv_clamped", il); } - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // Q/K Layernorm if (model.layers[il].attn_q_norm) { @@ -7972,26 +7952,16 @@ struct llm_build_mpt : public llm_graph_context { model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); - cb(Qcur, "Qcur", il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - cb(Kcur, "Kcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - } else { - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cb(Qcur, "Qcur", il); - - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cb(Kcur, "Kcur", il); } - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); @@ -8240,11 +8210,9 @@ struct llm_build_qwen : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = 
ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd)); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -9219,21 +9187,17 @@ struct llm_build_phi2 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9357,21 +9321,17 @@ struct llm_build_phi3 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9621,18 +9581,14 @@ struct 
llm_build_gpt2 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); @@ -9727,9 +9683,7 @@ struct llm_build_codeshell : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -12601,9 +12555,7 @@ struct llm_build_gptneox : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -13736,18 +13688,14 @@ struct llm_build_jais : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, 
n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); @@ -13859,8 +13807,7 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); @@ -13993,8 +13940,7 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } Qcur = ggml_rope_ext( @@ -17293,16 +17239,14 @@ struct llm_build_plamo2 : public llm_graph_context_mamba { const int64_t k_offset = n_embd_head_q * n_head; const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; - ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); - ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); From 85a5ea3671b9047aa9d8420e1454d152b599c164 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 09:39:00 +0300 Subject: [PATCH 2/7] kv-cache : make cpy_k and cpy_v implementation more readable ggml-ci --- 
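
The refactor hinges on one layout invariant: dims 0 and 1 of k_cur/v_cur can
be merged into a single row of n_embd_gqa elements only when the heads are
packed back-to-back, i.e. the byte stride of dim 1 equals the row size of
dim 0. The code open-codes this check as an assert (see the TODO below about
adding a ggml helper); a minimal sketch of what such a hypothetical helper
could look like:

    // true when dims 0 and 1 of t can be merged without a copy:
    // the stride between consecutive heads (nb[1]) must equal the
    // packed byte size of one head row
    static bool ggml_can_merge_dims_01(const struct ggml_tensor * t) {
        return ggml_row_size(t->type, t->ne[0]) == t->nb[1];
    }

When the invariant holds, a strided ggml_view_2d over [n_embd_gqa, n_tokens]
replaces the earlier reshape/cont round-trips, and ggml_set_rows can scatter
the merged rows directly into the (possibly multi-stream) cache buffer.
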
src/llama-kv-cache.cpp | 70 ++++++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c3e9bc74a89..d09cc1d8dba 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1018,21 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm const int32_t ikv = map_layer_ids.at(il); - auto * k = layers[ikv].k; + ggml_tensor * k = layers[ikv].k; + + const int64_t n_embd_head = k_cur->ne[0]; + const int64_t n_head = k_cur->ne[1]; + const int64_t n_token = k_cur->ne[2]; - const int64_t n_embd_k_gqa = k_cur->ne[0]*k_cur->ne[1]; - const int64_t n_tokens = k_cur->ne[2]; + const int64_t n_embd_gqa = n_embd_head*n_head; // we can merge dims 0 and 1 - assert(k_cur->nb[0]*k_cur->ne[0] == k_cur->nb[1]); + // TODO: add ggml helper function for this? + assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]); + + k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_token , k_cur->nb[2], 0); + + const int64_t n_stream = k->ne[2]; + + if (n_stream > 1) { + const uint64_t kv_size = get_size(); - //k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); - k_cur = ggml_view_2d(ctx, k_cur, n_embd_k_gqa, n_tokens, k_cur->nb[2], 0); + assert(n_embd_gqa == k->ne[0]); + assert(kv_size == k->ne[1]); - if (k->ne[2] > 1) { - k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); + // merge the buffer across all streams because the idxs are global + k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream); } + // store the current K values into the cache return ggml_set_rows(ctx, k, k_cur, k_idxs); } @@ -1043,39 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; - const int64_t n_tokens = v_cur->ne[2]; + const int64_t n_embd_head = v_cur->ne[0]; + const int64_t n_head = v_cur->ne[1]; + const int64_t n_token = v_cur->ne[2]; + + const int64_t n_embd_gqa = n_embd_head*n_head; // we can merge dims 0 and 1 - assert(v_cur->nb[0]*v_cur->ne[0] == v_cur->nb[1]); + assert(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]); - //v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); + const int64_t n_stream = v->ne[2]; + // take this branch when FA is enabled (the V cache is not transposed) if (!v_trans) { - v_cur = ggml_view_2d(ctx, v_cur, n_embd_v_gqa, n_tokens, v_cur->nb[2], 0); + v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_token, v_cur->nb[2], 0); + + if (n_stream > 1) { + const uint64_t kv_size = get_size(); + + assert(n_embd_gqa == v->ne[0]); + assert(kv_size == v->ne[1]); - if (v->ne[2] > 1) { - v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); + // merge the buffer across all streams because the idxs are global + v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream); } return ggml_set_rows(ctx, v, v_cur, v_idxs); } - if (v_cur->nb[1]*v_cur->ne[1] != v_cur->nb[2]) { - v_cur = ggml_cont_2d (ctx, v_cur, n_embd_v_gqa, n_tokens); + if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) { + // we can merge dims 0, 1 and 2 + v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_token); } else { - v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); + // otherwise -> make a copy to get contiguous data + v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_token); } // [TAG_V_CACHE_VARIABLE] - if (n_embd_v_gqa < v->ne[0]) { - v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + if (n_embd_gqa < v->ne[0]) { + v_cur = 
ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0); } - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); + // in this branch the v_idxs are constructed in such a way that each row becomes a single element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v)); - v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); + v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur)); return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } From 3dec397b92eee9f4c009b70c88cc69f9617d3b77 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 09:46:55 +0300 Subject: [PATCH 3/7] cont : add comments ggml-ci --- src/llama-kv-cache.cpp | 12 ++++++------ src/llama-kv-cache.h | 7 +++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index d09cc1d8dba..5a1251bb36a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1022,7 +1022,7 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm const int64_t n_embd_head = k_cur->ne[0]; const int64_t n_head = k_cur->ne[1]; - const int64_t n_token = k_cur->ne[2]; + const int64_t n_tokens = k_cur->ne[2]; const int64_t n_embd_gqa = n_embd_head*n_head; @@ -1030,7 +1030,7 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm // TODO: add ggml helper function for this? assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]); - k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_token , k_cur->nb[2], 0); + k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0); const int64_t n_stream = k->ne[2]; @@ -1057,7 +1057,7 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm const int64_t n_embd_head = v_cur->ne[0]; const int64_t n_head = v_cur->ne[1]; - const int64_t n_token = v_cur->ne[2]; + const int64_t n_tokens = v_cur->ne[2]; const int64_t n_embd_gqa = n_embd_head*n_head; @@ -1068,7 +1068,7 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm // take this branch when FA is enabled (the V cache is not transposed) if (!v_trans) { - v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_token, v_cur->nb[2], 0); + v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0); if (n_stream > 1) { const uint64_t kv_size = get_size(); @@ -1085,10 +1085,10 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) { // we can merge dims 0, 1 and 2 - v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_token); + v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens); } else { // otherwise -> make a copy to get contiguous data - v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_token); + v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens); } // [TAG_V_CACHE_VARIABLE] diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index b545bf5b9cf..d7b14e05ad6 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -317,9 +317,16 @@ class llama_kv_cache_context : public llama_memory_context_i { ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location + // - k_cur [n_embd_head_k, n_head_k, n_tokens] + // - k_idxs [n_batch] + // - v_cur [n_embd_head_v, n_head_v, n_tokens] + // - v_idxs [n_batch] or [n_batch*n_embd_v_gqa] depending if V cache is transposed ggml_tensor * cpy_k(ggml_context * ctx, 
ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; + // create destination indices for each head of the current batch for where it would be written in the KV cache + // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but + // helps understand the implementation logic of cpy_k and cpy_v ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; From c62c354f4f517abf2c4d29f7a54c6e2c3ac5bc0f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 09:49:05 +0300 Subject: [PATCH 4/7] cont : minor fix [no ci] --- src/llama-kv-cache.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index d7b14e05ad6..c2abda1ad27 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -318,9 +318,9 @@ class llama_kv_cache_context : public llama_memory_context_i { // store k_cur and v_cur in the cache based on the provided head location // - k_cur [n_embd_head_k, n_head_k, n_tokens] - // - k_idxs [n_batch] + // - k_idxs [n_tokens] // - v_cur [n_embd_head_v, n_head_v, n_tokens] - // - v_idxs [n_batch] or [n_batch*n_embd_v_gqa] depending if V cache is transposed + // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; From 1efa9e8a5d4a3ce97d05208a78aae564a9d23a18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 09:50:01 +0300 Subject: [PATCH 5/7] cont : one more fix --- src/llama-kv-cache.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 5a1251bb36a..3f8c2638020 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1035,7 +1035,7 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm const int64_t n_stream = k->ne[2]; if (n_stream > 1) { - const uint64_t kv_size = get_size(); + const int64_t kv_size = get_size(); assert(n_embd_gqa == k->ne[0]); assert(kv_size == k->ne[1]); @@ -1071,7 +1071,7 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0); if (n_stream > 1) { - const uint64_t kv_size = get_size(); + const int64_t kv_size = get_size(); assert(n_embd_gqa == v->ne[0]); assert(kv_size == v->ne[1]); From d6be191d673cdb5cbbbf8ad8a17b29f2f10129dd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Sep 2025 09:53:40 +0300 Subject: [PATCH 6/7] cont : clarity ggml-ci --- src/llama-kv-cache.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3f8c2638020..3a20da39354 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1096,7 +1096,7 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0); } - // in this branch the v_idxs are constructed in such a way that each row becomes a single element + // in this branch the v_idxs are constructed in such a way that each row is a single head element ggml_tensor * 
v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
 
     v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));

From 60d6e7c6fd8bacac0892b8722f5d5c585139cb43 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 8 Sep 2025 09:57:01 +0300
Subject: [PATCH 7/7] kv-cache : require contiguous heads of k_cur and v_cur

ggml-ci
---
 src/llama-kv-cache.cpp | 4 ++--
 src/llama-kv-cache.h   | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 3a20da39354..885be072a75 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1028,7 +1028,7 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
 
     // we can merge dims 0 and 1
     // TODO: add ggml helper function for this?
-    assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
+    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
 
     k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
 
@@ -1062,7 +1062,7 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     const int64_t n_embd_gqa = n_embd_head*n_head;
 
     // we can merge dims 0 and 1
-    assert(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
+    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
 
     const int64_t n_stream = v->ne[2];
 
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index c2abda1ad27..30de013f5f7 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -317,6 +317,7 @@ class llama_kv_cache_context : public llama_memory_context_i {
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
+    // note: the heads in k_cur and v_cur should be laid out contiguously in memory
     // - k_cur [n_embd_head_k, n_head_k, n_tokens]
     // - k_idxs [n_tokens]
     // - v_cur [n_embd_head_v, n_head_v, n_tokens]
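
A short worked sketch of the transposed-V branch at the end of cpy_v, for
reference (this mirrors the code in PATCH 2/6 above rather than adding new
behavior): ggml_set_rows copies whole rows of its destination, so to scatter
individual elements into the transposed layout both the cache and the
current values are flattened into columns of 1-element rows:

    // every cache element becomes its own 1-element row
    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));

    // same for the current values: [1, n_tokens*n_embd_v_gqa]
    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));

    // v_idxs carries one global destination index per element
    return ggml_set_rows(ctx, v_view, v_cur, v_idxs);

This is also why build_input_v_idxs produces n_tokens*n_embd_v_gqa indices
when the V cache is transposed, versus n_tokens indices in the FA case where
whole n_embd_v_gqa-sized rows are written at once.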