
Commit cf0e3ba

model : avoid ggml_cont_3d for fused QKV weights (ggml-org#15662)

* model : avoid ggml_cont_3d for fused QKV weights
  ggml-ci
* kv-cache : make cpy_k and cpy_v implementation more readable
  ggml-ci
* cont : add comments
  ggml-ci
* cont : minor fix
  [no ci]
* cont : one more fix
* cont : clarity
  ggml-ci
* kv-cache : require contiguous heads of k_cur and v_cur
  ggml-ci

1 parent d413dca · commit cf0e3ba
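Why this matters: with fused QKV weights, the K and V currents are strided 3D views into the projection output, so flattening them with ggml_reshape_2d previously required a ggml_cont_3d copy first. The standalone sketch below (not part of the commit; sizes and names are illustrative) shows how a strided ggml_view_2d merges the head dims without any data movement, which is the trick the new cpy_k/cpy_v code relies on:

#include <assert.h>
#include <stdio.h>

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // hypothetical sizes for a fused QKV projection output: one row per token
    const int64_t n_embd_head = 128, n_head = 32, n_head_kv = 8, n_tokens = 4;
    const int64_t n_embd_qkv  = n_embd_head*(n_head + 2*n_head_kv);

    struct ggml_tensor * qkv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_qkv, n_tokens);

    // the K slice as a 3D view [n_embd_head, n_head_kv, n_tokens]:
    // head rows are contiguous within a token, but tokens stride by the full QKV row
    struct ggml_tensor * k_cur = ggml_view_3d(ctx, qkv,
            n_embd_head, n_head_kv, n_tokens,
            ggml_row_size(qkv->type, n_embd_head),         // nb[1]: stride between heads
            qkv->nb[1],                                    // nb[2]: stride between tokens
            ggml_row_size(qkv->type, n_embd_head*n_head)); // offset: skip the Q slice

    // the view is not contiguous overall, so ggml_reshape_2d would need a
    // ggml_cont_3d copy first - this copy is exactly what the commit avoids
    assert(!ggml_is_contiguous(k_cur));

    // but the head rows are back to back within each token (dims 0 and 1 are
    // mergeable), so a strided 2D view suffices - no data movement at all
    assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
    struct ggml_tensor * k_2d = ggml_view_2d(ctx, k_cur,
            n_embd_head*n_head_kv, n_tokens, k_cur->nb[2], 0);

    printf("k_2d: [%lld, %lld]\n", (long long) k_2d->ne[0], (long long) k_2d->ne[1]);

    ggml_free(ctx);
    return 0;
}

The mergeability precondition in the sketch is the same one the commit enforces via GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]) in the diff below.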

File tree

3 files changed: +99, -107 lines


src/llama-kv-cache.cpp

Lines changed: 55 additions & 15 deletions
@@ -1018,16 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) {
 
     const int32_t ikv = map_layer_ids.at(il);
 
-    auto * k = layers[ikv].k;
+    ggml_tensor * k = layers[ikv].k;
+
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_tokens    = k_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
-    const int64_t n_tokens = k_cur->ne[2];
+    // we can merge dims 0 and 1
+    // TODO: add ggml helper function for this?
+    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
 
-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
 
-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const int64_t kv_size = get_size();
+
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size    == k->ne[1]);
+
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
     }
 
+    // store the current K values into the cache
     return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }

@@ -1038,28 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) {
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_tokens    = v_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    // we can merge dims 0 and 1
+    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
 
+    const int64_t n_stream = v->ne[2];
+
+    // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const int64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size    == v->ne[1]);
+
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
         }
 
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    } else {
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    }
+
     // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
     }
 
-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // in this branch the v_idxs are constructed in such a way that each row is a single head element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
 
     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
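In the transposed-V branch above, v_cur cannot always be flattened with a free view, so the code chooses between ggml_reshape_2d and a ggml_cont_2d copy based on the dim-2 stride. A small sketch of the two cases (illustrative sizes, not from the commit):

#include <assert.h>

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 64, n_head = 4, n_tokens = 2;
    const int64_t n_embd_gqa  = n_embd_head*n_head;

    // case 1: a fully contiguous [n_embd_head, n_head, n_tokens] tensor:
    // the dim-2 stride equals the merged row size, so the reshape is free
    struct ggml_tensor * v_cont = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
    assert(ggml_row_size(v_cont->type, n_embd_gqa) == v_cont->nb[2]);
    struct ggml_tensor * a = ggml_reshape_2d(ctx, v_cont, n_embd_gqa, n_tokens);

    // case 2: a strided view (e.g. the V slice of a fused QKV row): the dim-2
    // stride is larger than the merged row size, so a contiguous copy is needed
    struct ggml_tensor * big = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_embd_gqa, n_tokens);
    struct ggml_tensor * v_view = ggml_view_3d(ctx, big, n_embd_head, n_head, n_tokens,
            ggml_row_size(big->type, n_embd_head), big->nb[1], 0);
    assert(ggml_row_size(v_view->type, n_embd_gqa) != v_view->nb[2]);
    struct ggml_tensor * b = ggml_cont_2d(ctx, v_view, n_embd_gqa, n_tokens);

    (void) a; (void) b;
    ggml_free(ctx);
    return 0;
}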

src/llama-kv-cache.h

Lines changed: 8 additions & 0 deletions
@@ -317,9 +317,17 @@ class llama_kv_cache_context : public llama_memory_context_i {
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
+    // note: the heads in k_cur and v_cur should be laid out contiguously in memory
+    // - k_cur  [n_embd_head_k, n_head_k, n_tokens]
+    // - k_idxs [n_tokens]
+    // - v_cur  [n_embd_head_v, n_head_v, n_tokens]
+    // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending on whether the V cache is transposed
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
 
+    // create destination indices for each head of the current batch, indicating where it will be written in the KV cache
+    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
+    // helps understand the implementation logic of cpy_k and cpy_v
     ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
     ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
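For reference, the index tensors described in the new comments would have the following shapes. This sketch is an assumption based on the header comments above, using I64 indices to match what ggml_set_rows consumes in the .cpp diff:

#include "ggml.h"

int main(void) {
    // no_alloc: we only inspect tensor metadata, no data buffers needed
    struct ggml_init_params params = { /*mem_size*/ 1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ true };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_tokens = 32, n_embd_v_gqa = 1024; // illustrative sizes

    // k_idxs: one destination row per token
    struct ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // v_idxs when the V cache is not transposed (FA path): also one row per token
    struct ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // v_idxs when the V cache is transposed: one index per head element,
    // matching the "each row is a single head element" scatter in cpy_v
    struct ggml_tensor * v_idxs_trans = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*n_embd_v_gqa);

    (void) k_idxs; (void) v_idxs; (void) v_idxs_trans;
    ggml_free(ctx);
    return 0;
}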
