Commit e5bfd55

cont : distinguish KV cache layers from model layers
ggml-ci
1 parent 06b6184 commit e5bfd55
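
In short: each KV cache layer now records which model layer it belongs to, and the cache's accessors are indexed by the cache's own layer index (ikv) rather than by the model layer index (il). The snippet below is a minimal, self-contained sketch of that distinction; the types and the layer layout are stand-ins invented for illustration, not the actual llama.cpp declarations.

// Simplified stand-in for the idea behind this commit: a KV cache layer carries
// its own model-layer index, so the cache may hold fewer layers than the model
// and callers must not assume that cache index == model layer index.
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_layer_sketch {
    uint32_t il; // layer index in the original model
    void *   k;  // stand-in for the ggml K tensor of this layer
    void *   v;  // stand-in for the ggml V tensor of this layer
};

int main() {
    // A hypothetical cache that covers only model layers 0, 2 and 5.
    std::vector<kv_layer_sketch> layers = {
        { 0, nullptr, nullptr },
        { 2, nullptr, nullptr },
        { 5, nullptr, nullptr },
    };

    // Callers index by ikv (position in the cache) and recover il from the entry.
    for (size_t ikv = 0; ikv < layers.size(); ++ikv) {
        std::printf("cache layer %u -> model layer %u\n",
                (unsigned) ikv, (unsigned) layers[ikv].il);
    }

    return 0;
}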

File tree: 2 files changed, 56 additions and 54 deletions

src/llama-kv-cache.cpp

Lines changed: 47 additions & 47 deletions
@@ -30,13 +30,11 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         bool offload,
         uint32_t kv_size,
         uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
-    const int32_t n_layer = hparams.n_layer;
-
     has_shift = false;
     can_shift = true;

     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
-            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), hparams.n_layer, can_shift, padding);

     GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");

@@ -49,7 +47,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -73,26 +71,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     used = 0;

     cells.resize(kv_size);
-    layers.resize(n_layer);
-
-    for (int i = 0; i < n_layer; i++) {
-        auto & layer = layers[i];

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

         const char * dev_name = "CPU";

         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

         if (offload) {
-            auto * dev = model.dev_layer(i);
+            auto * dev = model.dev_layer(il);
             buft = ggml_backend_dev_buffer_type(dev);

             dev_name = ggml_backend_dev_name(dev);
         }

-        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name);
+        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);

         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
@@ -104,7 +99,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(

         // TODO: enable
 #if 0
-        if (hparams.is_swa(i)) {
+        if (hparams.is_swa(il)) {
             k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, hparams.n_swa);
             v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, hparams.n_swa);
         } else {
@@ -116,11 +111,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
 #endif

-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
+        ggml_format_name(k, "cache_k_l%d", il);
+        ggml_format_name(v, "cache_v_l%d", il);

-        layer.k = k;
-        layer.v = v;
+        layers.push_back({ il, k, v });
     }

     // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -565,8 +559,10 @@ uint32_t llama_kv_cache_unified::get_n() const {
     return n;
 }

-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
-    auto * k = layers[il].k;
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t ikv) const {
+    auto * k = layers[ikv].k;
+
+    const uint32_t il = layers[ikv].il;

     return ggml_view_3d(ctx, k,
             hparams.n_embd_head_k, hparams.n_head_kv(il), n,
@@ -575,8 +571,10 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) cons
             0);
 }

-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
-    auto * v = layers[il].v;
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t ikv) const {
+    auto * v = layers[ikv].v;
+
+    const uint32_t il = layers[ikv].il;

     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
@@ -595,8 +593,10 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons
             0);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    auto * k = layers[il].k;
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t ikv) const {
+    auto * k = layers[ikv].k;
+
+    const uint32_t il = layers[ikv].il;

     const int64_t n_tokens = k_cur->ne[2];

@@ -607,8 +607,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    auto * v = layers[il].v;
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t ikv) const {
+    auto * v = layers[ikv].v;
+
+    const uint32_t il = layers[ikv].il;

     const int64_t n_tokens = v_cur->ne[2];

@@ -890,8 +892,6 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    const auto & n_layer = hparams.n_layer;
-
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

@@ -904,8 +904,8 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
     ggml_set_input(inp->k_shift);

-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const auto & layer = layers[il];
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;

         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1028,8 +1028,8 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         nm++;
     }

-    for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
-        const auto & layer = layers[il];
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;

         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
         const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
@@ -1084,7 +1084,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 }

 bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
-    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_layer = layers.size();

     const uint32_t n_kv   = cell_max();
     const uint32_t n_used = used;
@@ -1309,7 +1309,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::

 void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
     const uint32_t v_trans = this->v_trans ? 1 : 0;
-    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_layer = layers.size();

     io.write(&v_trans, sizeof(v_trans));
     io.write(&n_layer, sizeof(n_layer));
@@ -1318,8 +1318,8 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::

     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const auto & layer = layers[il];
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;

         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

@@ -1340,8 +1340,8 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     }

     if (!v_trans) {
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -1364,8 +1364,8 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;

-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -1485,8 +1485,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     io.read_to(&v_trans, sizeof(v_trans));
     io.read_to(&n_layer, sizeof(n_layer));

-    if (n_layer != hparams.n_layer) {
-        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+    if (n_layer != layers.size()) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
         return false;
     }
     if (cell_count > size) {
@@ -1499,8 +1499,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     }

     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const auto & layer = layers[il];
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;

         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

@@ -1529,8 +1529,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     }

     if (!this->v_trans) {
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -1559,8 +1559,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         }
     } else {
         // For each layer, read the values for each cell (transposed)
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
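
A consequence of the decoupling visible in the hunks above: layer counts used for defragmentation and for session save/load now come from layers.size() rather than hparams.n_layer, so restored state is validated against the cache's own layer count. The fragment below is an illustrative stand-in for that check, with made-up types; it is not code from this commit.

// Illustrative sketch of the layer-count validation after this change.
// In the real code the count is read from the session stream via llama_io_read_i;
// here a plain integer stands in for it.
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_layer_sketch {
    uint32_t il; // layer index in the original model
};

static bool check_layer_count(const std::vector<kv_layer_sketch> & layers, uint32_t n_layer_from_file) {
    // validate against the cache's own layers, not against hparams.n_layer
    if (n_layer_from_file != layers.size()) {
        std::fprintf(stderr, "mismatched layer count (%u instead of %u)\n",
                n_layer_from_file, (uint32_t) layers.size());
        return false;
    }
    return true;
}

int main() {
    const std::vector<kv_layer_sketch> layers = { { 0 }, { 1 }, { 2 } };

    std::printf("count 3 accepted: %d\n", check_layer_count(layers, 3)); // 1
    std::printf("count 4 accepted: %d\n", check_layer_count(layers, 4)); // 0

    return 0;
}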

src/llama-kv-cache.h

Lines changed: 9 additions & 7 deletions
@@ -158,11 +158,11 @@ class llama_kv_cache_unified : public llama_kv_cache {

     uint32_t get_n() const;

-    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t ikv) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t ikv) const;

-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t ikv) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t ikv) const;

     void set_input_kq_mask    (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_kq_mask_swa(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
@@ -200,8 +200,10 @@ class llama_kv_cache_unified : public llama_kv_cache {
     };

     struct kv_layer {
-        ggml_tensor * k = nullptr;
-        ggml_tensor * v = nullptr;
+        uint32_t il; // layer index in the original model
+
+        ggml_tensor * k;
+        ggml_tensor * v;
     };

     bool has_shift = false;
@@ -229,7 +231,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;

-    std::vector<kv_cell> cells;
+    std::vector<kv_cell>  cells;
     std::vector<kv_layer> layers;

     // pending cell updates that are not yet committed
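
Why carrying il in kv_layer matters, sketched with a hypothetical helper that is not part of this commit: if a cache covers only a subset of the model's layers, the mapping from a model layer to its cache slot is no longer the identity, and the stored il field is what makes the lookup possible. All names below are invented for the example.

// Hypothetical helper: find the cache slot (ikv) that stores a given model layer (il).
// Returns -1 when that model layer has no entry in this cache.
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_layer_sketch {
    uint32_t il; // layer index in the original model
};

static int32_t ikv_for_il(const std::vector<kv_layer_sketch> & layers, uint32_t il) {
    for (size_t ikv = 0; ikv < layers.size(); ++ikv) {
        if (layers[ikv].il == il) {
            return (int32_t) ikv;
        }
    }
    return -1;
}

int main() {
    // A cache covering only model layers 0, 2 and 5.
    const std::vector<kv_layer_sketch> layers = { { 0 }, { 2 }, { 5 } };

    std::printf("model layer 2 -> cache slot %d\n", ikv_for_il(layers, 2)); //  1
    std::printf("model layer 3 -> cache slot %d\n", ikv_for_il(layers, 3)); // -1

    return 0;
}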
