
Commit 3f196be

graph : remove build_attn_with_sinks overload (ggml-org#15469)
ggml-ci
1 parent 97ae596 · commit 3f196be

File tree: 3 files changed (+107, -133 lines)

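The change folds the optional attention-sinks tensor into every build_attn overload (placed before v_mla, matching the build_attn_mha parameter order) and deletes the temporary build_attn_with_sinks wrapper. A minimal call-site sketch of the migration follows; the variable names mirror the parameters in the diff, and passing nullptr for kq_b is illustrative, not taken from the commit:

```cpp
// Before: sinks required a dedicated overload; plain build_attn had no sinks.
cur = build_attn_with_sinks(inp, wo, wo_b, q_cur, k_cur, v_cur,
        /*kq_b=*/nullptr, v_mla, sinks, kq_scale, il);

// After: one signature; sinks sits before v_mla, and callers
// without sinks simply pass nullptr.
cur = build_attn(inp, wo, wo_b, q_cur, k_cur, v_cur,
        /*kq_b=*/nullptr, sinks, v_mla, kq_scale, il);
```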

src/llama-graph.cpp: 9 additions & 31 deletions
```diff
@@ -1223,8 +1223,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
-        ggml_tensor * v_mla,
         ggml_tensor * sinks,
+        ggml_tensor * v_mla,
               float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];

@@ -1360,6 +1360,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
@@ -1381,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;

-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
@@ -1443,6 +1444,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
@@ -1469,7 +1471,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);

-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
@@ -1495,33 +1497,8 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-              float   kq_scale,
-                int   il) const {
-    return build_attn_with_sinks(
-            inp,
-            wo,
-            wo_b,
-            q_cur,
-            k_cur,
-            v_cur,
-            kq_b,
-            v_mla,
-            nullptr,
-            kq_scale,
-            il);
-}
-
-ggml_tensor * llm_graph_context::build_attn_with_sinks(
-        llm_graph_input_attn_kv_iswa * inp,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
         ggml_tensor * sinks,
+        ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1561,7 +1538,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);

-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
@@ -1600,6 +1577,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
               float   kq_scale,
                 int   il) const {
@@ -1615,7 +1593,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;

-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
```
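For context, the sinks tensor carries one extra logit per query head (shape [n_head_q], per the header comments). It enters the softmax normalizer but contributes no value, so part of the attention mass can drain into it. The following standalone sketch (plain C++, not ggml code) illustrates that arithmetic under the usual attention-sink definition:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Standalone sketch: softmax over one row of scaled KQ logits for one head,
// with a per-head sink logit added to the normalizer. The sink absorbs
// probability mass but produces no value contribution, so the returned
// weights sum to less than 1 whenever the sink logit is finite.
std::vector<float> softmax_with_sink(const std::vector<float> & kq,
                                     float sink, float kq_scale) {
    // subtract the running max for numerical stability
    float max_l = sink;
    for (float l : kq) {
        max_l = std::max(max_l, l * kq_scale);
    }

    float denom = std::exp(sink - max_l); // the sink's share of the normalizer
    std::vector<float> p(kq.size());
    for (size_t i = 0; i < kq.size(); ++i) {
        p[i] = std::exp(kq[i] * kq_scale - max_l);
        denom += p[i];
    }
    for (float & x : p) {
        x /= denom;
    }
    return p;
}
```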

src/llama-graph.h: 12 additions & 22 deletions
```diff
@@ -680,14 +680,14 @@ struct llm_graph_context {
     //

     ggml_tensor * build_attn_mha(
-            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
-            ggml_tensor * kq_b,
-            ggml_tensor * kq_mask,
-            ggml_tensor * sinks,
-            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                  float   kq_scale) const;
+            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            ggml_tensor * sinks,   // [n_head_q]
+            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+                  float   kq_scale) const;

     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

@@ -699,6 +699,7 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
@@ -713,6 +714,7 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
@@ -728,21 +730,8 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                  float   kq_scale,
-                    int   il) const;
-
-    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
-    ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_iswa * inp,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             ggml_tensor * sinks, // [n_head_q]
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;

@@ -756,6 +745,7 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
```
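With the wrapper gone, the header leaves a single build_attn declaration per attention-input type, each taking sinks directly. A hypothetical caller that wants sinks allocates a 1-D tensor with one logit per query head; ctx and n_head below are assumptions for illustration, not names from the diff:

```cpp
// Hypothetical setup: one sink logit per query head, shape [n_head_q] as
// documented in the header. Callers that do not use sinks pass nullptr.
ggml_tensor * sinks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_head);

cur = build_attn(inp, wo, wo_b, q_cur, k_cur, v_cur,
        /*kq_b=*/nullptr, sinks, v_mla, kq_scale, il);
```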
