Commit 2eacb4c

graph : simplify attention api
ggml-ci
1 parent e17e4b7 commit 2eacb4c
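
At a glance, the change collapses the two-step attention graph API (build_attn_kv_store() to write the current K/V into the cache, followed by build_attn_qkv() to compute the attention output) into a single virtual build_attn() that takes k_cur, v_cur and q_cur together and performs the cache store internally. The sketch below models that consolidation with placeholder types so it compiles on its own; graph_iface, tensor and the printf bodies are illustrative assumptions, not the actual llama.cpp declarations shown in the diffs.

// Toy model of the consolidated attention API (placeholder types; not llama.cpp code).
#include <cstdio>

struct tensor { const char * name; };

struct graph_iface {
    // Before this commit the interface exposed two virtuals per layer:
    //   build_attn_kv_store(k_cur, v_cur, ...)   // write K/V into the cache
    //   build_attn_qkv(wo, wo_b, q_cur, ...)     // compute the attention output
    // After the commit a single virtual does both:
    virtual tensor * build_attn(
            tensor * wo, tensor * wo_b,
            tensor * k_cur, tensor * v_cur, tensor * q_cur,
            int n_tokens, float kq_scale, int il, bool worst_case) = 0;

    virtual ~graph_iface() = default;
};

struct graph_kv_self : graph_iface {
    tensor out { "kqv_out" };

    tensor * build_attn(
            tensor * /*wo*/, tensor * /*wo_b*/,
            tensor * k_cur, tensor * v_cur, tensor * q_cur,
            int n_tokens, float /*kq_scale*/, int il, bool /*worst_case*/) override {
        // step 1: store K/V for this layer in the cache (formerly build_attn_kv_store)
        std::printf("store K=%s V=%s for layer %d (%d tokens)\n", k_cur->name, v_cur->name, il, n_tokens);
        // step 2: attend over the cache with Q and apply wo/wo_b (formerly build_attn_qkv)
        std::printf("attend with Q=%s\n", q_cur->name);
        return &out;
    }
};

int main() {
    graph_kv_self g;
    tensor k { "k_cur" }, v { "v_cur" }, q { "q_cur" };
    tensor * cur = g.build_attn(nullptr, nullptr, &k, &v, &q, /*n_tokens=*/32, /*kq_scale=*/0.125f, /*il=*/0, /*worst_case=*/false);
    std::printf("result: %s\n", cur->name);
    return 0;
}

The payoff at the call site is visible in src/llama-model.cpp below: one lgf->build_attn(...) call per layer replaces the former build_attn_kv_store / build_attn_qkv pair.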

File tree

4 files changed: +36 -64 lines changed


src/llama-context.cpp

Lines changed: 28 additions & 37 deletions

@@ -2567,63 +2567,56 @@ void llama_context_kv_self::build_attn_inp(
     }
 }

-void llama_context_kv_self::build_attn_kv_store(
+ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
+        ggml_tensor * q_cur,
         int32_t n_tokens,
-        int64_t il,
+        float kq_scale,
+        int il,
         bool worst_case) {
     const auto & hparams = model.hparams;

     const auto & n_ctx = cparams.n_ctx;

-    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
-
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

-    GGML_ASSERT(kv_self.size == n_ctx);
+    // store to KV cache
+    {
+        const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;

-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
-    //cb(k_cache_view, "k_cache_view", il);
+        GGML_ASSERT(kv_self.size == n_ctx);

-    // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
+        //cb(k_cache_view, "k_cache_view", il);

-    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+        // note: storing RoPE-ed version of K in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

-    struct ggml_tensor * v_cache_view = nullptr;
+        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

-    if (cparams.flash_attn) {
-        v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
-                (  n_ctx)*ggml_element_size(kv_self.v_l[il]),
-                (kv_head)*ggml_element_size(kv_self.v_l[il]));
+        struct ggml_tensor * v_cache_view = nullptr;

-        v_cur = ggml_transpose(ctx0, v_cur);
-    }
-    //cb(v_cache_view, "v_cache_view", il);
+        if (cparams.flash_attn) {
+            v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
+        } else {
+            // note: the V cache is transposed when not using flash attention
+            v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
+                    (  n_ctx)*ggml_element_size(kv_self.v_l[il]),
+                    (kv_head)*ggml_element_size(kv_self.v_l[il]));

-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
-}
+            v_cur = ggml_transpose(ctx0, v_cur);
+        }
+        //cb(v_cache_view, "v_cache_view", il);

-ggml_tensor * llama_context_kv_self::build_attn_qkv(
-        ggml_context * ctx0,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        int32_t n_tokens,
-        float kq_scale,
-        int il,
-        bool worst_case) {
-    const auto & hparams = model.hparams;
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+    }

-    const auto & n_ctx = cparams.n_ctx;
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;

@@ -2657,8 +2650,6 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(

     const int64_t n_head    = hparams.n_head(il);
     const int64_t n_head_kv = hparams.n_head_kv(il);
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

     struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);

src/llama-context.h

Lines changed: 4 additions & 10 deletions

@@ -376,20 +376,13 @@ class llama_context_kv_self : public llama_context {
             bool swa,
             bool worst_case) override;

-    virtual void build_attn_kv_store(
-            ggml_context * ctx0,
-            ggml_cgraph * gf,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
-            int32_t n_tokens,
-            int64_t il,
-            bool worst_case) override;
-
-    virtual ggml_tensor * build_attn_qkv(
+    virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
             ggml_tensor * q_cur,
             int32_t n_tokens,
             float kq_scale,
@@ -443,6 +436,7 @@ class llama_context_kv_self : public llama_context {

 // a recurrent transformer (ie.e RWKV, Mamba)
 // TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
+//class llama_context_recurrent : public llama_context {
 class llama_context_recurrent : public llama_context_kv_self {
 public:
     llama_context_recurrent(

src/llama-graph.h

Lines changed: 3 additions & 10 deletions

@@ -88,20 +88,13 @@ class llama_graph_i {
             bool swa,
             bool worst_case) = 0;

-    virtual void build_attn_kv_store(
-            ggml_context * ctx0,
-            ggml_cgraph * gf,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
-            int32_t n_tokens,
-            int64_t il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_attn_qkv(
+    virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
             ggml_tensor * q_cur,
             int32_t n_tokens,
             float kq_scale,

src/llama-model.cpp

Lines changed: 1 addition & 7 deletions

@@ -4258,13 +4258,7 @@ struct llm_build_context {
         ggml_build_forward_expand(gf, k_cur);
         ggml_build_forward_expand(gf, v_cur);

-        //build_kv_store(gf, k_cur, v_cur, il);
-        lgf->build_attn_kv_store(ctx0, gf, k_cur, v_cur, n_tokens, il, worst_case);
-
-        struct ggml_tensor * cur;
-
-        //cur = build_kqv(gf, wo, wo_b, q_cur, kq_mask, kq_scale, il);
-        cur = lgf->build_attn_qkv(ctx0, gf, wo, wo_b, q_cur, n_tokens, kq_scale, il, worst_case);
+        ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, k_cur, v_cur, q_cur, n_tokens, kq_scale, il, worst_case);
         cb(cur, "kqv_out", il);

         return cur;
