
Commit 4299536

cont : simplify logic for KV cache store/view in attn
ggml-ci
1 parent 4feadaa commit 4299536

4 files changed: 122 additions and 121 deletions

src/llama-graph.cpp

Lines changed: 18 additions & 95 deletions
@@ -1058,16 +1058,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * v_mla,
-               bool   v_trans,
              float    kq_scale) const {
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-    //const int64_t n_head    = hparams.n_head(il);
-    //const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
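
Note: build_attn_mha no longer takes a v_trans flag. It infers whether V is stored transposed by comparing strides (a transposed view has nb[1] > nb[2]) and applies the 0,2,1,3 permutes to Q/K/V itself, so callers pass tensors in their [n_embd_head, n_head, n_tokens] layout. A minimal standalone sketch of the same stride check, with made-up sizes (not code from this commit):

// standalone sketch of the nb[1] > nb[2] check; sizes are illustrative
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    // contiguous [n_embd_head, n_kv, n_head_kv] tensor: row stride < plane stride
    ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 32, 8);
    printf("contiguous: v_trans = %d\n", (int) (v->nb[1] > v->nb[2])); // 0

    // swapping dims 1 and 2 produces a view whose row stride exceeds its plane stride
    ggml_tensor * v_t = ggml_permute(ctx, v, 0, 2, 1, 3);
    printf("permuted  : v_trans = %d\n", (int) (v_t->nb[1] > v_t->nb[2])); // 1

    ggml_free(ctx);
    return 0;
}
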
@@ -1206,17 +1202,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1284,82 +1274,21 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
-    const auto & kv_layer = kv_self->get_layer(il);
-
-    const auto & n_ctx = cparams.n_ctx;
-
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    const auto n_tokens = q_cur->ne[2];
-
-    const bool v_trans = !cparams.flash_attn;
-
     // store to KV cache
     {
-        const auto kv_head = kv_self->head;
-
-        GGML_ASSERT(kv_self->size == n_ctx);
-
-        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_layer.k, n_tokens*n_embd_k_gqa, ggml_row_size(kv_layer.k->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
-
-        // note: storing RoPE-ed version of K in the KV cache
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
-
-        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
-
-        ggml_tensor * v_cache_view = nullptr;
-
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_layer.v, n_tokens*n_embd_v_gqa, ggml_row_size(kv_layer.v->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_layer.v, n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_layer.v),
-                    (kv_head)*ggml_element_size(kv_layer.v));
-
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
-
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
     }
 
     const bool is_swa = hparams.is_swa(il);
-    const int64_t n_head_kv = hparams.n_head_kv(il);
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-    const auto n_kv = kv_self->n;
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k =
-        ggml_view_3d(ctx0, kv_layer.k,
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv_layer.k->type, n_embd_k_gqa),
-                ggml_row_size(kv_layer.k->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_layer.v,
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_layer.v->type, n_embd_v_gqa),
-                ggml_row_size(kv_layer.v->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_layer.v,
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_layer.v)*n_ctx,
-                ggml_element_size(kv_layer.v)*n_ctx*n_embd_head_v,
-                0);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1410,17 +1339,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {

src/llama-graph.h

Lines changed: 4 additions & 5 deletions
@@ -504,13 +504,12 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
             ggml_cgraph * gf,
-            ggml_tensor * q,     // [n_embd_head_q, n_tokens, n_head_q]
-            ggml_tensor * k,     // [n_embd_head_k, n_tokens, n_head_k]
-            ggml_tensor * v,     // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+            ggml_tensor * q,     // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k,     // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v,     // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
             ggml_tensor * kq_b,
             ggml_tensor * kq_mask,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                   bool   v_trans,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                  float    kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
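
Note: the header comments now document the layout that callers hand to build_attn as [n_embd_head, n_head, n_tokens]; build_attn_mha permutes to [n_embd_head, n_tokens, n_head] internally. A minimal sketch of producing that layout from a flat projection output with ggml_reshape_3d, using illustrative sizes (not code from this commit):

// standalone sketch; n_embd_head/n_head/n_tokens are made-up values
#include "ggml.h"

int main() {
    ggml_init_params params = { 16*1024*1024, nullptr, false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 128;
    const int64_t n_head      = 8;
    const int64_t n_tokens    = 4;

    // projection output: one n_embd-sized row per token
    ggml_tensor * q_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_head*n_head, n_tokens);

    // split the embedding dimension into heads -> [n_embd_head, n_head, n_tokens]
    q_cur = ggml_reshape_3d(ctx, q_cur, n_embd_head, n_head, n_tokens);

    GGML_ASSERT(q_cur->ne[0] == n_embd_head);
    GGML_ASSERT(q_cur->ne[1] == n_head);
    GGML_ASSERT(q_cur->ne[2] == n_tokens);

    ggml_free(ctx);
    return 0;
}
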

src/llama-kv-cache.cpp

Lines changed: 82 additions & 17 deletions
@@ -565,6 +565,73 @@ const llama_kv_cache_unified::kv_layer & llama_kv_cache_unified::get_layer(int32
     return layers[il];
 }
 
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
+    auto * k = layers[il].k;
+
+    return ggml_view_3d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n,
+            ggml_row_size(k->type, hparams.n_embd_head_k),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
+            0);
+}
+
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
+    auto * v = layers[il].v;
+
+    if (!v_trans) {
+        // note: v->nb[1] <= v->nb[2]
+        return ggml_view_3d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n,
+                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
+                0);
+    }
+
+    // note: v->nb[1] > v->nb[2]
+    return ggml_view_3d(ctx, v,
+            n, hparams.n_head_kv(il), hparams.n_embd_head_v,
+            ggml_element_size(v)*v->ne[1]*hparams.n_embd_head_v, // v->nb[1]
+            ggml_element_size(v)*v->ne[1],                       // v->nb[2]
+            0);
+}
+
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+    auto * k = layers[il].k;
+
+    const int64_t n_tokens = k_cur->ne[2];
+
+    ggml_tensor * k_view = ggml_view_1d(ctx, k,
+            n_tokens*hparams.n_embd_k_gqa(il),
+            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head);
+
+    return ggml_cpy(ctx, k_cur, k_view);
+}
+
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+    auto * v = layers[il].v;
+
+    const int64_t n_tokens = v_cur->ne[2];
+
+    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+
+    ggml_tensor * v_view = nullptr;
+
+    if (!v_trans) {
+        v_view = ggml_view_1d(ctx, v,
+                n_tokens*hparams.n_embd_v_gqa(il),
+                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head);
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
+                (v->ne[1])*ggml_element_size(v),
+                (    head)*ggml_element_size(v));
+
+        v_cur = ggml_transpose(ctx, v_cur);
+    }
+
+    return ggml_cpy(ctx, v_cur, v_view);
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
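
Note: cpy_k/cpy_v build a view of the cache tensor starting at cell `head` and return the ggml_cpy node that writes the current batch into it; get_k/get_v return read views over the first `n` cells. A standalone sketch of the same view-and-copy pattern for K, with made-up sizes and a hypothetical head offset (not code from this commit):

// standalone sketch of the view + copy pattern used by cpy_k; all sizes are made up
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = { 64*1024*1024, nullptr, false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_k_gqa = 256;  // per-cell K size for one layer (illustrative)
    const int64_t kv_size      = 1024; // cache capacity in cells (illustrative)
    const int64_t n_tokens     = 8;    // tokens in the current ubatch (illustrative)
    const int64_t head         = 32;   // first free cell (hypothetical)

    // the layer's K cache: one row of n_embd_k_gqa values per cell
    ggml_tensor * k_cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_k_gqa, kv_size);

    // K computed for the current tokens
    ggml_tensor * k_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_k_gqa, n_tokens);

    // 1D window over cells [head, head + n_tokens)
    ggml_tensor * k_view = ggml_view_1d(ctx, k_cache,
            n_tokens*n_embd_k_gqa,
            ggml_row_size(k_cache->type, n_embd_k_gqa)*head);

    // the node a graph would execute to append the batch into the cache
    ggml_tensor * cpy = ggml_cpy(ctx, k_cur, k_view);

    printf("copy writes %zu bytes at byte offset %zu\n",
            ggml_nbytes(cpy), (size_t) k_view->view_offs);

    ggml_free(ctx);
    return 0;
}
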
@@ -633,7 +700,7 @@ void llama_kv_cache_unified::set_input_kq_mask_swa(ggml_tensor * dst, const llam
     const int64_t n_seqs = ubatch->n_seqs;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    float * data_swa = (float *) dst->data;
+    float * data = (float *) dst->data;
 
     const int64_t n_kv = n;
 
@@ -661,28 +728,26 @@ void llama_kv_cache_unified::set_input_kq_mask_swa(ggml_tensor * dst, const llam
 
                     // may need to cut off old tokens for sliding window
                     // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                    if (data_swa) {
-                        if (hparams.n_attn_chunk) {
-                            llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                            if (cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                f = -INFINITY;
-                            }
-                        } else {
-                            if (pos - cells[i].pos >= (int32_t)hparams.n_swa) {
-                                f = -INFINITY;
-                            }
+                    if (hparams.n_attn_chunk) {
+                        llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                        if (cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                            f = -INFINITY;
+                        }
+                    } else {
+                        if (pos - cells[i].pos >= (int32_t)hparams.n_swa) {
+                            f = -INFINITY;
                         }
-                        data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                     }
+                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                 }
             }
         }
 
         // mask padded tokens
-        if (data_swa) {
+        if (data) {
             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                 for (int j = 0; j < n_kv; ++j) {
-                    data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                 }
             }
         }
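
Note: data_swa is renamed to data and the if (data_swa) guard around the per-cell masking is dropped; the cutoff conditions themselves are unchanged. In the chunked-attention branch, a token at position pos may only attend to cache cells whose position is at or after (pos / n_attn_chunk) * n_attn_chunk. A tiny standalone example of that cutoff with illustrative values (not code from this commit):

// standalone illustration of the chunk cutoff arithmetic; values are made up
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t n_attn_chunk = 8;  // chunk size (illustrative)
    const int32_t pos          = 21; // query token position (illustrative)

    // same arithmetic as the mask: start of the chunk containing `pos`
    const int32_t pos_chunk_start = (pos / n_attn_chunk) * n_attn_chunk; // 16

    for (int32_t cell_pos = 12; cell_pos <= pos; ++cell_pos) {
        const bool masked = cell_pos < pos_chunk_start;
        printf("cell pos %2d -> %s\n", (int) cell_pos, masked ? "-INF" : "visible");
    }
    return 0;
}
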
@@ -1296,12 +1361,12 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
             }
         }
     } else {
+        // When v is transposed, we also need the element size and get the element ranges from each row
+        const uint32_t kv_size = size;
+
        for (uint32_t il = 0; il < n_layer; ++il) {
             const auto & layer = layers[il];
 
-            // When v is transposed, we also need the element size and get the element ranges from each row
-            const uint32_t kv_size = size;
-
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type

src/llama-kv-cache.h

Lines changed: 18 additions & 4 deletions
@@ -133,9 +133,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
 
     ~llama_kv_cache_unified() = default;
 
-    // find how many cells are currently in use
-    uint32_t cell_max() const;
-
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_impl also uses it, so it
     // cannot be freely changed after a slot has been allocated.
@@ -146,7 +143,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;
 
-
     //
     // llama_memory_i
     //
@@ -193,8 +189,15 @@ class llama_kv_cache_unified : public llama_kv_cache {
 
     const kv_layer & get_layer(int32_t il) const;
 
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
+
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+
     void set_input_kq_mask    (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_kq_mask_swa(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+
     void set_input_k_shift    (ggml_tensor * dst) const;
     void set_input_pos_bucket (ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
@@ -238,6 +241,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // return true if cells have been moved
     bool defrag_prepare(int32_t n_max_nodes);
 
+    // find how many cells are currently in use
+    uint32_t cell_max() const;
+
     size_t total_size() const;
 
     size_t size_k_bytes() const;
@@ -269,6 +275,14 @@ class llama_kv_cache_unified : public llama_kv_cache {
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
+//
+// llama_kv_cache_unified_swa
+//
+
+//class llama_kv_cache_unified_swa : public llama_kv_cache {
+//public:
+//};
+
 //
 // llama_kv_cache_recurrent
 //
