
Commit 85a5ea3

kv-cache : make cpy_k and cpy_v implementation more readable
ggml-ci
1 parent bb1202b · commit 85a5ea3

File tree: 1 file changed, +47 -23 lines changed


src/llama-kv-cache.cpp

Lines changed: 47 additions & 23 deletions
@@ -1018,21 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
 
     const int32_t ikv = map_layer_ids.at(il);
 
-    auto * k = layers[ikv].k;
+    ggml_tensor * k = layers[ikv].k;
+
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_token     = k_cur->ne[2];
 
-    const int64_t n_embd_k_gqa = k_cur->ne[0]*k_cur->ne[1];
-    const int64_t n_tokens = k_cur->ne[2];
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
     // we can merge dims 0 and 1
-    assert(k_cur->nb[0]*k_cur->ne[0] == k_cur->nb[1]);
+    // TODO: add ggml helper function for this?
+    assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
+
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_token, k_cur->nb[2], 0);
+
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const uint64_t kv_size = get_size();
 
-    //k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
-    k_cur = ggml_view_2d(ctx, k_cur, n_embd_k_gqa, n_tokens, k_cur->nb[2], 0);
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size == k->ne[1]);
 
-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
     }
 
+    // store the current K values into the cache
     return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
 
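Note on the "we can merge dims 0 and 1" check above: the new assert expresses the same contiguity condition as the old nb/ne arithmetic, but via ggml_row_size. A minimal standalone sketch, not part of this commit and with made-up tensor sizes, of what the check guarantees:

// sketch: for a contiguous [n_embd_head, n_head, n_token] tensor, the byte stride of dim 1
// equals the byte size of one dim-0 row, so dims 0 and 1 form one contiguous row of
// n_embd_head*n_head elements and can be viewed as a 2D tensor without copying
#include <assert.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 128, n_head = 8, n_token = 4; // hypothetical sizes

    struct ggml_tensor * k_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_token);

    // the condition asserted in cpy_k/cpy_v
    assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);

    // merge dims 0 and 1: view the same data as [n_embd_head*n_head, n_token]
    struct ggml_tensor * k_2d = ggml_view_2d(ctx, k_cur, n_embd_head*n_head, n_token, k_cur->nb[2], 0);
    (void) k_2d;

    ggml_free(ctx);
    return 0;
}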

@@ -1043,39 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens = v_cur->ne[2];
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_token     = v_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
     // we can merge dims 0 and 1
-    assert(v_cur->nb[0]*v_cur->ne[0] == v_cur->nb[1]);
+    assert(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
 
-    //v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    const int64_t n_stream = v->ne[2];
 
+    // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        v_cur = ggml_view_2d(ctx, v_cur, n_embd_v_gqa, n_tokens, v_cur->nb[2], 0);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_token, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const uint64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size == v->ne[1]);
 
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
         }
 
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
-    if (v_cur->nb[1]*v_cur->ne[1] != v_cur->nb[2]) {
-        v_cur = ggml_cont_2d (ctx, v_cur, n_embd_v_gqa, n_tokens);
+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_token);
     } else {
-        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_token);
     }
 
     // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
    }
 
-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // in this branch the v_idxs are constructed in such a way that each row becomes a single element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
 
     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
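On the n_stream > 1 branches in both functions: a rough illustration, not from this commit and with hypothetical sizes, of why the cache buffer is reshaped to [n_embd_gqa, kv_size*n_stream] before ggml_set_rows. In ggml's memory layout dim 1 varies faster than dim 2, so once the per-stream dimension is folded in, a global row index of the form stream*kv_size + cell (presumably how the global k_idxs/v_idxs are built) addresses a cell of any stream:

// hypothetical example of the global row index arithmetic
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t kv_size  = 4096; // cells per stream (get_size() in the code above)
    const int64_t n_stream = 4;    // number of sequence streams

    const int64_t stream = 2;      // which stream the token belongs to
    const int64_t cell   = 17;     // cell index within that stream

    // row index into the merged [n_embd_gqa, kv_size*n_stream] buffer
    const int64_t idx = stream*kv_size + cell;

    printf("global row %lld of %lld\n", (long long) idx, (long long) (kv_size*n_stream));
    return 0;
}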
