Skip to content

Commit 184ce30

Browse files
authored
merge even more conts and reshapes
1 parent 20b433a commit 184ce30

File tree

1 file changed

+32
-25
lines changed

1 file changed

+32
-25
lines changed

src/llama-model.cpp

Lines changed: 32 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -7367,32 +7367,38 @@ struct llm_build_bert : public llm_graph_context {
73677367
cb(cur, "bqkv", il);
73687368
}
73697369

7370-
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
7371-
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
7372-
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
7370+
Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
7371+
Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
7372+
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
7373+
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
73737374
} else {
73747375
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
73757376
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
73767377
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
7378+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
73777379
}
73787380

73797381
if (model.layers[il].attn_q_norm) {
73807382
Qcur = build_norm(Qcur,
73817383
model.layers[il].attn_q_norm,
73827384
model.layers[il].attn_q_norm_b,
73837385
LLM_NORM, il);
7386+
7387+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7388+
} else {
7389+
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
73847390
}
73857391

73867392
if (model.layers[il].attn_k_norm) {
73877393
Kcur = build_norm(Kcur,
73887394
model.layers[il].attn_k_norm,
73897395
model.layers[il].attn_k_norm_b,
73907396
LLM_NORM, il);
7391-
}
73927397

7393-
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7394-
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7395-
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
7398+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7399+
} else {
7400+
Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7401+
}
73967402

73977403
// RoPE
73987404
if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
@@ -7770,7 +7776,7 @@ struct llm_build_mpt : public llm_graph_context {
77707776

77717777
ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
77727778
ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
7773-
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
7779+
ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
77747780

77757781
cb(Qcur, "Qcur", il);
77767782
cb(Kcur, "Kcur", il);
@@ -7789,17 +7795,18 @@ struct llm_build_mpt : public llm_graph_context {
77897795
model.layers[il].attn_k_norm_b,
77907796
LLM_NORM, il);
77917797
cb(Kcur, "Kcur", il);
7798+
7799+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7800+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
77927801
} else {
7793-
Qcur = ggml_cont(ctx0, Qcur);
7802+
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
77947803
cb(Qcur, "Qcur", il);
77957804

7796-
Kcur = ggml_cont(ctx0, Kcur);
7805+
Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
77977806
cb(Kcur, "Kcur", il);
77987807
}
77997808

7800-
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7801-
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7802-
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
7809+
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
78037810

78047811
cb(Qcur, "Qcur", il);
78057812
cb(Kcur, "Kcur", il);
@@ -9026,21 +9033,21 @@ struct llm_build_phi2 : public llm_graph_context {
90269033

90279034
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
90289035
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
9029-
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
9036+
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
9037+
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
90309038
} else {
90319039
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
90329040
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
90339041
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
90349042
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
90359043
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
9044+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
90369045
}
90379046

90389047
cb(Qcur, "Qcur", il);
90399048
cb(Kcur, "Kcur", il);
90409049
cb(Vcur, "Vcur", il);
90419050

9042-
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
9043-
90449051
Qcur = ggml_rope_ext(
90459052
ctx0, Qcur, inp_pos, nullptr,
90469053
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9164,21 +9171,21 @@ struct llm_build_phi3 : public llm_graph_context {
91649171

91659172
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd));
91669173
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd));
9167-
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
9174+
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
9175+
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
91689176
} else {
91699177
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
91709178
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
91719179
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
91729180
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
91739181
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
9182+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
91749183
}
91759184

91769185
cb(Qcur, "Qcur", il);
91779186
cb(Kcur, "Kcur", il);
91789187
cb(Vcur, "Vcur", il);
91799188

9180-
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
9181-
91829189
Qcur = ggml_rope_ext(
91839190
ctx0, Qcur, inp_pos, rope_factors,
91849191
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -13526,6 +13533,7 @@ struct llm_build_chatglm : public llm_graph_context {
1352613533
}
1352713534
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
1352813535
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
13536+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
1352913537
} else {
1353013538
cur = build_lora_mm(model.layers[il].wqkv, cur);
1353113539
cb(cur, "wqkv", il);
@@ -13535,11 +13543,10 @@ struct llm_build_chatglm : public llm_graph_context {
1353513543
}
1353613544
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
1353713545
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
13538-
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
13546+
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
13547+
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
1353913548
}
1354013549

13541-
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
13542-
1354313550
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
1354413551
Qcur = ggml_rope_ext(
1354513552
ctx0, Qcur, inp_pos, nullptr,
@@ -13660,6 +13667,7 @@ struct llm_build_glm4 : public llm_graph_context {
1366013667
}
1366113668
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
1366213669
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
13670+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
1366313671
} else {
1366413672
cur = build_lora_mm(model.layers[il].wqkv, cur);
1366513673
cb(cur, "wqkv", il);
@@ -13669,11 +13677,10 @@ struct llm_build_glm4 : public llm_graph_context {
1366913677
}
1367013678
Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
1367113679
Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
13672-
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
13680+
Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
13681+
Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
1367313682
}
1367413683

13675-
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
13676-
1367713684
Qcur = ggml_rope_ext(
1367813685
ctx0, Qcur, inp_pos, nullptr,
1367913686
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,

0 commit comments

Comments
 (0)