
Commit 016774b

cont : maintain map of model layer id -> kv cache layer id
ggml-ci
1 parent 7c5deb0 commit 016774b
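
Context for the change: the unified KV cache stores its per-layer K/V tensors in a layers vector, while callers identify layers by the model's layer id (il). This commit records, at construction time, which slot in layers each model layer ends up in, so the accessors can be called with il directly. Below is a minimal standalone sketch of that mapping pattern; the struct and values are hypothetical placeholders, not the actual llama.cpp types.

#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

// Hypothetical stand-in for the per-layer cache entry.
struct kv_layer_sketch {
    int32_t il; // model layer id this entry belongs to
    int32_t k;  // placeholder for the K tensor
    int32_t v;  // placeholder for the V tensor
};

int main() {
    std::vector<kv_layer_sketch> layers;        // dense storage, cache order
    std::map<int32_t, int32_t>   map_layer_ids; // model layer id -> index into layers

    // Build entries for model layers 0 and 2 only, to show that the model
    // layer ids do not have to match the cache's own indices.
    for (int32_t il : { 0, 2 }) {
        map_layer_ids[il] = (int32_t) layers.size(); // index the entry is about to occupy
        layers.push_back({ il, 0, 0 });
    }

    // Accessors can now take the model layer id and resolve the cache index.
    const int32_t ikv = map_layer_ids.at(2);
    std::printf("model layer 2 -> KV cache layer %d\n", ikv);

    return 0;
}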

2 files changed: 21 additions & 27 deletions


src/llama-kv-cache.cpp

Lines changed: 13 additions & 23 deletions
@@ -103,23 +103,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        // TODO: enable
-#if 0
-        if (hparams.is_swa(il)) {
-            k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, hparams.n_swa);
-            v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, hparams.n_swa);
-        } else {
-            k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-            v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
-        }
-#else
         k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
         v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
-#endif
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
 
+        map_layer_ids[il] = layers.size();
         layers.push_back({ il, k, v });
     }
 
@@ -565,10 +555,10 @@ uint32_t llama_kv_cache_unified::get_n() const {
     return n;
 }
 
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t ikv) const {
-    auto * k = layers[ikv].k;
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
+    const int32_t ikv = map_layer_ids.at(il);
 
-    const uint32_t il = layers[ikv].il;
+    auto * k = layers[ikv].k;
 
     return ggml_view_3d(ctx, k,
             hparams.n_embd_head_k, hparams.n_head_kv(il), n,
@@ -577,10 +567,10 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t ikv) con
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t ikv) const {
-    auto * v = layers[ikv].v;
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
+    const int32_t ikv = map_layer_ids.at(il);
 
-    const uint32_t il = layers[ikv].il;
+    auto * v = layers[ikv].v;
 
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
@@ -599,10 +589,10 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t ikv) con
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t ikv) const {
-    auto * k = layers[ikv].k;
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+    const int32_t ikv = map_layer_ids.at(il);
 
-    const uint32_t il = layers[ikv].il;
+    auto * k = layers[ikv].k;
 
     const int64_t n_tokens = k_cur->ne[2];
 
@@ -613,10 +603,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t ikv) const {
-    auto * v = layers[ikv].v;
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+    const int32_t ikv = map_layer_ids.at(il);
 
-    const uint32_t il = layers[ikv].il;
+    auto * v = layers[ikv].v;
 
     const int64_t n_tokens = v_cur->ne[2];
 

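A note on the lookup used in the new accessors: std::map::at throws std::out_of_range when the key is missing and, unlike operator[], never inserts an entry, which also keeps it usable from const member functions. The snippet below (standard library behavior only, not code from this commit) shows that a model layer id without a KV entry fails loudly instead of silently creating a zero mapping.

#include <cstdint>
#include <cstdio>
#include <map>
#include <stdexcept>

int main() {
    // Two model layers (0 and 2) have cache entries; layer 1 does not.
    std::map<int32_t, int32_t> map_layer_ids = { { 0, 0 }, { 2, 1 } };

    try {
        (void) map_layer_ids.at(1); // missing key -> throws, nothing is inserted
    } catch (const std::out_of_range &) {
        std::printf("model layer 1 has no KV cache layer\n");
    }

    // map_layer_ids[1] would instead insert a default value (0), and it is a
    // non-const member, so it could not be used in the const accessors.
    return 0;
}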
src/llama-kv-cache.h

Lines changed: 8 additions & 4 deletions
@@ -7,6 +7,7 @@
 
 #include "ggml-cpp.h"
 
+#include <map>
 #include <set>
 #include <vector>
 
@@ -161,11 +162,11 @@ class llama_kv_cache_unified : public llama_kv_cache {
 
     uint32_t get_n() const;
 
-    ggml_tensor * get_k(ggml_context * ctx, int32_t ikv) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t ikv) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t ikv) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t ikv) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
 
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_kq_mask_swa(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
@@ -239,6 +240,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
     std::vector<kv_cell> cells;
    std::vector<kv_layer> layers;
 
+    // model layer id -> KV cache layer id
+    std::map<int32_t, int32_t> map_layer_ids;
+
     // pending cell updates that are not yet committed
     struct {
         std::vector<slot_range> ranges;
