
Commit 3dec397

cont : add comments
ggml-ci
1 parent 85a5ea3 commit 3dec397

2 files changed (+13, -6 lines)


src/llama-kv-cache.cpp

Lines changed: 6 additions & 6 deletions

```diff
@@ -1022,15 +1022,15 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
 
     const int64_t n_embd_head = k_cur->ne[0];
     const int64_t n_head      = k_cur->ne[1];
-    const int64_t n_token     = k_cur->ne[2];
+    const int64_t n_tokens    = k_cur->ne[2];
 
     const int64_t n_embd_gqa = n_embd_head*n_head;
 
     // we can merge dims 0 and 1
     // TODO: add ggml helper function for this?
     assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
 
-    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_token , k_cur->nb[2], 0);
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
 
     const int64_t n_stream = k->ne[2];
 
@@ -1057,7 +1057,7 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
 
     const int64_t n_embd_head = v_cur->ne[0];
     const int64_t n_head      = v_cur->ne[1];
-    const int64_t n_token     = v_cur->ne[2];
+    const int64_t n_tokens    = v_cur->ne[2];
 
     const int64_t n_embd_gqa = n_embd_head*n_head;
 
@@ -1068,7 +1068,7 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
 
     // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_token, v_cur->nb[2], 0);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
 
         if (n_stream > 1) {
             const uint64_t kv_size = get_size();
@@ -1085,10 +1085,10 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
 
     if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
         // we can merge dims 0, 1 and 2
-        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_token);
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
     } else {
         // otherwise -> make a copy to get contiguous data
-        v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_token);
+        v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens);
     }
 
     // [TAG_V_CACHE_VARIABLE]
```
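The trick the first hunk relies on can be seen in isolation: when the rows of dims 0 and 1 are laid out back to back (the condition the `assert` checks), a 3D `[n_embd_head, n_head, n_tokens]` tensor can be reinterpreted as a 2D `[n_embd_gqa, n_tokens]` matrix with a zero-copy view. Below is a minimal standalone sketch, not part of the commit; the sizes are made up for illustration, and only the public ggml API is used:

```cpp
#include <cassert>
#include "ggml.h"

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 64; // hypothetical head size
    const int64_t n_head      = 8;  // hypothetical number of heads
    const int64_t n_tokens    = 4;  // hypothetical ubatch size

    ggml_tensor * k_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);

    const int64_t n_embd_gqa = n_embd_head*n_head;

    // precondition asserted in cpy_k: rows along dim 1 are contiguous
    assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);

    // merge dims 0 and 1 into a single row of n_embd_gqa elements per token,
    // without copying any data - the stride per token (nb[2]) is unchanged
    ggml_tensor * k_2d = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);

    GGML_ASSERT(ggml_nelements(k_2d) == ggml_nelements(k_cur));

    ggml_free(ctx);
    return 0;
}
```

The `cpy_v` hunks apply the same idea one level up: if the `n_embd_gqa` rows are also contiguous across dim 2, a plain `ggml_reshape_2d` suffices; otherwise `ggml_cont_2d` first materializes a contiguous copy.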

src/llama-kv-cache.h

Lines changed: 7 additions & 0 deletions

```diff
@@ -317,9 +317,16 @@ class llama_kv_cache_context : public llama_memory_context_i {
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
+    //   - k_cur  [n_embd_head_k, n_head_k, n_tokens]
+    //   - k_idxs [n_batch]
+    //   - v_cur  [n_embd_head_v, n_head_v, n_tokens]
+    //   - v_idxs [n_batch] or [n_batch*n_embd_v_gqa] depending if V cache is transposed
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
 
+    // create destination indices for each head of the current batch for where it would be written in the KV cache
+    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
+    // helps understand the implementation logic of cpy_k and cpy_v
     ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
     ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
 
```
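Taken together, the new comments describe a two-step protocol for writing a ubatch into the cache. A hypothetical graph-build fragment, not from this commit (`mctx` is assumed to be a `llama_kv_cache_context *`, and `ctx`, `ubatch`, `k_cur`, `v_cur`, `il` are assumed to exist in the caller), would pair the two APIs like this:

```cpp
// 1) build destination indices for the current ubatch
//    (per the new comment, these address the global KV cache, not a single stream)
ggml_tensor * k_idxs = mctx->build_input_k_idxs(ctx, ubatch); // [n_batch]
ggml_tensor * v_idxs = mctx->build_input_v_idxs(ctx, ubatch); // [n_batch] or [n_batch*n_embd_v_gqa]

// 2) copy the current K/V tensors into the cache at those indices, for layer il
//    k_cur [n_embd_head_k, n_head_k, n_tokens], v_cur [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * k_cpy = mctx->cpy_k(ctx, k_cur, k_idxs, il);
ggml_tensor * v_cpy = mctx->cpy_v(ctx, v_cur, v_idxs, il);
```

Note that the index builders take only the ubatch while `cpy_k`/`cpy_v` take the layer index, which suggests the indices can be built once per ubatch and reused across layers.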
