Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 9 additions & 31 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1223,8 +1223,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_tensor * v,
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mla,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale) const {
const bool v_trans = v->nb[1] > v->nb[2];

Expand Down Expand Up @@ -1360,6 +1360,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
Expand All @@ -1381,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;

ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
Expand Down Expand Up @@ -1443,6 +1444,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
Expand All @@ -1469,7 +1471,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);

ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
Expand All @@ -1495,33 +1497,8 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
return build_attn_with_sinks(
inp,
wo,
wo_b,
q_cur,
k_cur,
v_cur,
kq_b,
v_mla,
nullptr,
kq_scale,
il);
}

ggml_tensor * llm_graph_context::build_attn_with_sinks(
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
Expand Down Expand Up @@ -1561,7 +1538,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks(
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);

ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
Expand Down Expand Up @@ -1600,6 +1577,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
Expand All @@ -1615,7 +1593,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;

ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
Expand Down
34 changes: 12 additions & 22 deletions src/llama-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -680,14 +680,14 @@ struct llm_graph_context {
//

ggml_tensor * build_attn_mha(
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * sinks,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale) const;
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale) const;

llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

Expand All @@ -699,6 +699,7 @@ struct llm_graph_context {
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
Expand All @@ -713,6 +714,7 @@ struct llm_graph_context {
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
Expand All @@ -728,21 +730,8 @@ struct llm_graph_context {
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;

// TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
ggml_tensor * build_attn_with_sinks(
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;

Expand All @@ -756,6 +745,7 @@ struct llm_graph_context {
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
Expand Down
Loading
Loading