
Commit de93218

Revert "graph : remove build_attn_with_sinks overload (ggml-org#15469)"

This reverts commit 3f196be.

1 parent 909c5c0

3 files changed: +133 additions, -107 deletions

src/llama-graph.cpp

Lines changed: 31 additions & 9 deletions
@@ -1223,8 +1223,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
-        ggml_tensor * sinks,
         ggml_tensor * v_mla,
+        ggml_tensor * sinks,
         float kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
@@ -1360,7 +1360,6 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        ggml_tensor * sinks,
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
@@ -1382,7 +1381,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1444,7 +1443,6 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        ggml_tensor * sinks,
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
@@ -1471,7 +1469,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1497,10 +1495,35 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        ggml_tensor * sinks,
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
+    return build_attn_with_sinks(
+            inp,
+            wo,
+            wo_b,
+            q_cur,
+            k_cur,
+            v_cur,
+            kq_b,
+            v_mla,
+            nullptr,
+            kq_scale,
+            il);
+}
+
+ggml_tensor * llm_graph_context::build_attn_with_sinks(
+        llm_graph_input_attn_kv_iswa * inp,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        ggml_tensor * sinks,
+        float kq_scale,
+        int il) const {
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
@@ -1538,7 +1561,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1577,7 +1600,6 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        ggml_tensor * sinks,
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
@@ -1593,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
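
For context on what the restored sinks argument feeds into: an attention sink is a learned per-head logit that joins the softmax normalization of the KQ scores but contributes no value row, so it drains probability mass away from the real positions. The standalone sketch below illustrates only that arithmetic; it is not ggml code, and every name in it is illustrative.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Illustration of the attention-sink softmax for one query and one head.
    int main() {
        std::vector<float> kq = {2.0f, 1.0f, 0.5f}; // scaled KQ logits over real positions
        float sink = 1.5f;                          // learned sink logit, one per head ([n_head_q])

        float denom = std::exp(sink);               // the sink joins the normalization ...
        for (float l : kq) denom += std::exp(l);

        float sum = 0.0f;
        for (float l : kq) {
            float w = std::exp(l) / denom;          // ... so each real position gets less mass
            sum += w;
            std::printf("weight = %.4f\n", w);
        }
        // the remainder, exp(sink)/denom, attends to "nothing": no V row is accumulated for it
        std::printf("mass on real positions = %.4f (the rest went to the sink)\n", sum);
    }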

src/llama-graph.h

Lines changed: 22 additions & 12 deletions
@@ -680,14 +680,14 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-            ggml_tensor * q,      // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k,      // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v,      // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
-            ggml_tensor * kq_b,
-            ggml_tensor * kq_mask,
-            ggml_tensor * sinks,  // [n_head_q]
-            ggml_tensor * v_mla,  // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale) const;
+             ggml_tensor * q,     // [n_embd_head_q, n_head_q, n_tokens]
+             ggml_tensor * k,     // [n_embd_head_k, n_head_k, n_tokens]
+             ggml_tensor * v,     // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
+             ggml_tensor * kq_b,
+             ggml_tensor * kq_mask,
+             ggml_tensor * sinks,
+             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+             float kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
 
@@ -699,7 +699,6 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
-            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -714,7 +713,6 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
-            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -730,11 +728,24 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
-            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
 
+    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
+    ggml_tensor * build_attn_with_sinks(
+            llm_graph_input_attn_kv_iswa * inp,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            ggml_tensor * sinks, // [n_head_q]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
@@ -745,7 +756,6 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
-            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
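
A hypothetical call-site sketch of the restored pair of entry points, derived only from the signatures above. The surrounding names (inp_attn, layer.wo, layer.bo, Qcur, Kcur, Vcur, layer.attn_sinks, kq_scale, il) follow common llama.cpp conventions but are assumptions for illustration, not lines from this commit:

    // fragment from a hypothetical llm_build_* graph builder (illustrative names):

    // regular layers: the plain overload, which per the diff above forwards
    // nullptr for sinks to build_attn_with_sinks internally
    ggml_tensor * attn_out = build_attn(inp_attn, layer.wo, layer.bo,
            Qcur, Kcur, Vcur, /*kq_b=*/nullptr, /*v_mla=*/nullptr, kq_scale, il);

    // sink-aware layers (gpt-oss style): pass the learned [n_head_q] sink tensor
    ggml_tensor * attn_out_sinks = build_attn_with_sinks(inp_attn, layer.wo, layer.bo,
            Qcur, Kcur, Vcur, /*kq_b=*/nullptr, /*v_mla=*/nullptr,
            layer.attn_sinks, kq_scale, il);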
