Commit 1a3aaa3

Iwan Kawrakow (ikawrakow) authored
Merge Q and K into a single tensor (ikawrakow#892)
* Merge Q and K into a single tensor

* Make V mul mat follow QK mul mat so they can be fused, which gives a slightly better TG performance.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent abb966e commit 1a3aaa3
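In rough terms, the change stores the Q and K projection weights back to back in one tensor so that a single mul mat produces both the Q and K activations, which are then recovered as views into the shared result. A minimal sketch of the idea (illustrative names; plain ggml_mul_mat instead of the repo's llm_build_lora_mm wrapper):

    // wqk holds the Q rows followed by the K rows: [n_embd, n_embd_head*(n_head + n_head_kv)]
    ggml_tensor * qk   = ggml_mul_mat(ctx0, wqk, cur);           // one projection for both Q and K
    // Q occupies the first n_embd_head*n_head values of each result row
    ggml_tensor * Qcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head,    n_tokens,
                                      n_embd_head*sizeof(float), qk->nb[1], 0);
    // K starts right after the Q part of each row
    ggml_tensor * Kcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head_kv, n_tokens,
                                      n_embd_head*sizeof(float), qk->nb[1],
                                      sizeof(float)*n_embd_head*n_head);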

File tree

4 files changed: 81 additions, 1 deletion


src/llama-build-context.cpp

Lines changed: 42 additions & 1 deletion
@@ -1270,6 +1270,7 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
 
 std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
         ggml_tensor * wqkv, ggml_tensor * bqkv,
+        ggml_tensor * wqk, ggml_tensor * bqk,
         ggml_tensor * wq, ggml_tensor * bq,
         ggml_tensor * wk, ggml_tensor * bk,
         ggml_tensor * wv, ggml_tensor * bv,
@@ -1307,6 +1308,40 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
         //ggml_build_forward_expand(gf, Vcur);
     }
 
+    if (wqk) {
+        auto qk = llm_build_lora_mm(lctx, ctx0, wqk, cur);
+        cb(qk, "qkv", il);
+        if (bqk) {
+            qk = ggml_add(ctx0, qk, bqk);
+            cb(qk, "qkv_b", il);
+        }
+        auto Vcur = llm_build_lora_mm(lctx, ctx0, wv, cur);
+        cb(Vcur, "Vcur", il);
+        if (bv) {
+            Vcur = ggml_add(ctx0, Vcur, bv);
+            cb(Vcur, "Vcur", il);
+        }
+        ggml_build_forward_expand(gf, qk);
+        ggml_build_forward_expand(gf, Vcur);
+        auto Qcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 0*sizeof(float)*(n_embd));
+        auto Kcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
+        cb(Qcur, "Qcur", il);
+        cb(Kcur, "Kcur", il);
+        if (q_norm) {
+            Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
+            cb(Qcur, "Qcur_normed", il);
+            ggml_build_forward_expand(gf, Qcur);
+        }
+        if (k_norm) {
+            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+            cb(Kcur, "Kcur_normed", il);
+            ggml_build_forward_expand(gf, Kcur);
+        }
+
+        return {Qcur, Kcur, Vcur};
+
+    }
+
     auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, wq, bq, wk, bk, wv, bv, attention_scale, il);
     auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, n_head, n_tokens);
     if (q_norm) {
@@ -1374,6 +1409,7 @@ ggml_cgraph * llm_build_context::build_llama() {
 
         auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
                 model.layers[il].wqkv, model.layers[il].bqkv,
+                model.layers[il].wqk, model.layers[il].bqk,
                 model.layers[il].wq, model.layers[il].bq,
                 model.layers[il].wk, model.layers[il].bk,
                 model.layers[il].wv, model.layers[il].bv,
@@ -3400,6 +3436,7 @@ ggml_cgraph * llm_build_context::build_qwen3() {
         {
             auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
                     model.layers[il].wqkv, nullptr,
+                    model.layers[il].wqk, nullptr,
                     model.layers[il].wq, nullptr,
                     model.layers[il].wk, nullptr,
                     model.layers[il].wv, nullptr,
@@ -3502,6 +3539,7 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
         {
             auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
                     model.layers[il].wqkv, nullptr,
+                    model.layers[il].wqk, nullptr,
                     model.layers[il].wq, nullptr, model.layers[il].wk, nullptr, model.layers[il].wv, nullptr,
                     model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
 
@@ -6403,6 +6441,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
         {
             auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
                     model.layers[il].wqkv, model.layers[il].bqkv,
+                    model.layers[il].wqk, model.layers[il].bqk,
                     model.layers[il].wq, model.layers[il].bq,
                     model.layers[il].wk, model.layers[il].bk,
                     model.layers[il].wv, model.layers[il].bv,
@@ -6814,6 +6853,7 @@ ggml_cgraph * llm_build_context::build_cohere2() {
 
         auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
                 model.layers[il].wqkv, model.layers[il].bqkv,
+                model.layers[il].wqk, model.layers[il].bqk,
                 model.layers[il].wq, model.layers[il].bq,
                 model.layers[il].wk, model.layers[il].bk,
                 model.layers[il].wv, model.layers[il].bv, nullptr, nullptr, 0.f, il);
@@ -8116,6 +8156,7 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
         {
            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
                    model.layers[il].wqkv, model.layers[il].bqkv,
+                   model.layers[il].wqk, model.layers[il].bqk,
                    model.layers[il].wq, model.layers[il].bq,
                    model.layers[il].wk, model.layers[il].bk,
                    model.layers[il].wv, model.layers[il].bv,
@@ -8234,7 +8275,7 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
         // self_attention
         {
             auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wqkv, model.layers[il].bqkv,
-                    nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
+                    nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                     model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.0f, il);
 
             if (rope_cache) {
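Design note, per the commit message: in the new wqk branch the V projection is built right after the merged QK mul mat and the two are expanded into the graph back to back,

    ggml_build_forward_expand(gf, qk);    // merged Q/K projection
    ggml_build_forward_expand(gf, Vcur);  // V projection scheduled immediately after

so the QK and V mul mats can be fused, which is where the slightly better TG performance comes from.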

src/llama-build-context.h

Lines changed: 1 addition & 0 deletions
@@ -152,6 +152,7 @@ struct llm_build_context {
 
     std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
             ggml_tensor * wqkv, ggml_tensor * bqkv,
+            ggml_tensor * wqk, ggml_tensor * bqk,
             ggml_tensor * wq, ggml_tensor * bq,
             ggml_tensor * wk, ggml_tensor * bk,
             ggml_tensor * wv, ggml_tensor * bv,

src/llama-load-tensors.cpp

Lines changed: 34 additions & 0 deletions
@@ -2495,6 +2495,40 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) {
             }
         }
     }
+    if (!fused_qkv && ml.merge_qkv && wq->type == wk->type && hparams.f_attention_scale == 0.0f) {
+        GGML_ASSERT(wq->ne[0] == n_embd && wq->ne[1] == n_head * n_embd_head_k);
+        GGML_ASSERT(wk->ne[0] == n_embd && wk->ne[1] == n_embd_gqa);
+        layer.wqk = ggml_new_tensor_2d(ctx_split, wq->type, n_embd, n_embd_head_k * (n_head + n_head_kv));
+        snprintf(layer.wqk->name, GGML_MAX_NAME, "blk.%d.attn_qk.weight", i);
+        layer.wq = ml.create_tensor_as_view(ctx_split, layer.wqk, wq_name.c_str(), { wq->ne[0], wq->ne[1] }, 0);
+        layer.wk = ml.create_tensor_as_view(ctx_split, layer.wqk, wk_name.c_str(), { wk->ne[0], wk->ne[1] }, wq->ne[1]*wq->nb[1]);
+        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+        printf("====================== Merged only Q and K in layer %d because V is of different type\n", i);
+        fused_qkv = true;
+        if (bias) {
+            auto bq_name = tn(LLM_TENSOR_ATTN_Q, "bias", i);
+            auto bk_name = tn(LLM_TENSOR_ATTN_K, "bias", i);
+            auto bv_name = tn(LLM_TENSOR_ATTN_V, "bias", i);
+            auto bq = ml.get_tensor_meta(bq_name.c_str());
+            auto bk = ml.get_tensor_meta(bk_name.c_str());
+            auto bv = ml.get_tensor_meta(bv_name.c_str());
+            if (bias == 2) {
+                GGML_ASSERT(bq && bk && bv);
+            } else {
+                GGML_ASSERT(!bq && !bk && !bv);
+            }
+            if (bq && bk && bv) {
+                GGML_ASSERT(bq->type == GGML_TYPE_F32 && bk->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_nrows(bq) == 1 && bq->ne[0] == wq->ne[1]);
+                GGML_ASSERT(ggml_nrows(bk) == 1 && bk->ne[0] == wk->ne[1]);
+                layer.bqk = ggml_new_tensor_1d(ctx_layer, bq->type, n_embd_head_k * (n_head + n_head_kv));
+                snprintf(layer.bqk->name, GGML_MAX_NAME, "blk.%d.attn_qk.bias", i);
+                layer.bq = ml.create_tensor_as_view(ctx_layer, layer.bqk, bq_name.c_str(), { bq->ne[0] }, 0);
+                layer.bk = ml.create_tensor_as_view(ctx_layer, layer.bqk, bk_name.c_str(), { bk->ne[0] }, bq->ne[0]*bq->nb[0]);
+                layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {layer.wv->ne[1]});
+            }
+        }
+    }
 
     if (!fused_qkv) {
         if (ml.merge_qkv) {
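For illustration, a self-contained sketch of the layout the loader change relies on (hypothetical sizes; plain ggml_view_2d instead of ml.create_tensor_as_view): the merged wqk tensor owns the storage, and the Q and K weights become row views into it at byte offsets 0 and wq->ne[1]*wq->nb[1], so loading the original per-tensor data fills the merged tensor in place.

    #include "ggml.h"

    int main() {
        struct ggml_init_params params = { /*mem_size =*/ 64u*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(params);

        // hypothetical sizes: n_embd = 1024, 16 query heads and 4 KV heads of size 64
        const int64_t n_embd = 1024, n_q_rows = 16*64, n_k_rows = 4*64;

        // parent tensor: Q rows followed by K rows
        struct ggml_tensor * wqk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_q_rows + n_k_rows);

        // Q view covers the first n_q_rows rows (byte offset 0)
        struct ggml_tensor * wq = ggml_view_2d(ctx, wqk, n_embd, n_q_rows, wqk->nb[1], 0);
        // K view starts right after the Q rows (byte offset n_q_rows * row stride)
        struct ggml_tensor * wk = ggml_view_2d(ctx, wqk, n_embd, n_k_rows, wqk->nb[1], n_q_rows*wqk->nb[1]);

        (void)wq; (void)wk;   // the views share wqk's data; nothing else to do in this sketch
        ggml_free(ctx);
        return 0;
    }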

src/llama-model.h

Lines changed: 4 additions & 0 deletions
@@ -154,6 +154,8 @@ struct llama_layer {
     struct ggml_tensor * wv = nullptr;
     struct ggml_tensor * wo = nullptr;
     struct ggml_tensor * wqkv = nullptr;
+    struct ggml_tensor * wqk = nullptr;
+    struct ggml_tensor * wkv = nullptr;
     struct ggml_tensor * wq_a = nullptr;
     struct ggml_tensor * wq_b = nullptr;
     struct ggml_tensor * wkv_a_mqa = nullptr;
@@ -176,6 +178,8 @@ struct llama_layer {
     struct ggml_tensor * bv = nullptr;
     struct ggml_tensor * bo = nullptr;
     struct ggml_tensor * bqkv = nullptr;
+    struct ggml_tensor * bqk = nullptr;
+    struct ggml_tensor * bkv = nullptr;
 
     // relative position bias
     struct ggml_tensor * attn_rel_b = nullptr;
