
Commit 2664a3d

graph : don't mutate the KV cache during defrag
ggml-ci
1 parent 24f9a30 commit 2664a3d
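
In short (based on the diff below): build_kv_self_defrag() previously const_cast the const KV cache and both computed and applied the cell-move plan while the graph was being built. With this change, llama_kv_cache_unified::defrag_prepare(), called from llama_context_kv_self::kv_self_update(), computes the plan ahead of time and stores it in defrag_info.ids; the defrag graph is built and computed only when defrag_prepare() reports work to do, so graph construction no longer mutates the cache. As part of the same cleanup, llm_graph_context caches the model architecture in a new arch member and drops its duplicate graph_max_nodes() helper.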

5 files changed: 189 additions & 172 deletions


src/llama-context.cpp

Lines changed: 23 additions & 21 deletions
@@ -1769,32 +1769,34 @@ void llama_context_kv_self::kv_self_update() {
     if (kv->do_defrag) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        ggml_backend_sched_reset(sched.get());
+        if (kv->defrag_prepare(graph_max_nodes())) {
+            ggml_backend_sched_reset(sched.get());
 
-        auto * gf = graph_init();
+            auto * gf = graph_init();
 
-        model.build_graph_kv_self_defrag(
-                {
-                    /*.ctx         =*/ ctx_compute.get(),
-                    /*.model       =*/ model,
-                    /*.cparams     =*/ cparams,
-                    /*.ubatch      =*/ {},
-                    /*.sched       =*/ sched.get(),
-                    /*.backend_cpu =*/ backend_cpu,
-                    /*.backends    =*/ backends,
-                    /*.cvec        =*/ nullptr,
-                    /*.loras       =*/ nullptr,
-                    /*.memory      =*/ nullptr,
-                    /*.cross       =*/ nullptr,
-                    /*.n_outputs   =*/ 0,
-                }, gf);
+            model.build_graph_kv_self_defrag(
+                    {
+                        /*.ctx         =*/ ctx_compute.get(),
+                        /*.model       =*/ model,
+                        /*.cparams     =*/ cparams,
+                        /*.ubatch      =*/ {},
+                        /*.sched       =*/ sched.get(),
+                        /*.backend_cpu =*/ backend_cpu,
+                        /*.backends    =*/ backends,
+                        /*.cvec        =*/ nullptr,
+                        /*.loras       =*/ nullptr,
+                        /*.memory      =*/ nullptr,
+                        /*.cross       =*/ nullptr,
+                        /*.n_outputs   =*/ 0,
+                    }, gf);
 
-        ggml_backend_sched_alloc_graph(sched.get(), gf);
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
 
-        // no input
-        //input_set({});
+            // no input
+            //input_set({});
 
-        graph_compute(gf, false);
+            graph_compute(gf, false);
+        }
 
     kv->do_defrag = false;
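
For context, here is a minimal, self-contained sketch of the kind of planning step that defrag_prepare() now owns, adapted from the scan removed from build_kv_self_defrag() further down in this commit. The real code works on llama_kv_cache_unified cells, updates the cache metadata as it plans, and caps the plan per contiguous block against the graph node budget; every name below except defrag_prepare() and defrag_info.ids (which appear in the diff) is illustrative, not a llama.cpp identifier.

// defrag_plan_sketch.cpp -- illustrative only, not llama.cpp code.
// Computes a move plan over a toy cache occupancy mask, mimicking the scan
// this commit moves out of graph construction: cell i moves to ids[i];
// ids[i] == i or ids[i] == n_kv means "cell i is not moved".
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<uint32_t> defrag_plan(const std::vector<bool> & occupied, uint32_t n_used, uint32_t max_moves) {
    const uint32_t n_kv = (uint32_t) occupied.size();

    std::vector<uint32_t> ids(n_kv, n_kv);
    std::vector<bool>     occ = occupied; // local copy so moved cells can be marked

    uint32_t n_moves = 0;

    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
        if (occ[i0]) {
            ids[i0] = i0; // already in place
            continue;
        }

        // found a hole at i0 - measure its size
        uint32_t nh = 1;
        while (i0 + nh < n_used && !occ[i0 + nh]) {
            nh++;
        }

        // cap the plan size (the real code counts contiguous blocks, not cells)
        if (n_moves + nh > max_moves) {
            break;
        }

        // fill the hole with the last nh occupied, not-yet-moved cells
        uint32_t nf = 0;
        for (uint32_t is = n_kv; is-- > i0 + nh && nf < nh; ) {
            if (!occ[is] || ids[is] != n_kv) {
                continue;
            }
            ids[is]      = i0 + nf; // cell `is` moves to slot i0 + nf
            occ[i0 + nf] = true;
            occ[is]      = false;
            nf++;
            n_moves++;
        }

        i0 += nh - 1;
    }

    return ids;
}

int main() {
    // 8-cell toy cache: true = occupied, false = hole; 5 cells are in use
    const std::vector<bool> occupied = {true, false, false, true, true, false, true, true};
    const auto ids = defrag_plan(occupied, /*n_used=*/5, /*max_moves=*/16);

    for (size_t i = 0; i < ids.size(); ++i) {
        std::printf("cell %zu -> %u\n", i, ids[i]);
    }
    return 0;
}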

src/llama-graph.cpp

Lines changed: 8 additions & 145 deletions
@@ -594,6 +594,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
 
 llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     model (params.model),
+    arch (model.arch),
     hparams (model.hparams),
     cparams (params.cparams),
     ubatch (params.ubatch),
@@ -633,13 +634,8 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }
 
-// TODO: deduplicate with llama_context::graph_max_nodes()
-int32_t llm_graph_context::graph_max_nodes() const {
-    return std::max<int32_t>(8192, 5*model.n_tensors());
-}
-
 int64_t llm_graph_context::n_pos_per_token() const {
-    return model.arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }
 
 // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
@@ -1251,8 +1247,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
     // TODO: replace hardcoded padding with ggml-provided padding
     if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
-        GGML_UNUSED(model);
-
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
@@ -1272,7 +1266,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         // while for some models F16 is enough, for others it is not, so we default to F32 here
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
-        if (model.arch == LLM_ARCH_GROK) {
+        if (arch == LLM_ARCH_GROK) {
             // need to do the following:
             // multiply by attn_output_multiplyer of 0.08838834764831845
             // and then :
@@ -1483,7 +1477,7 @@ ggml_tensor * llm_graph_context::build_attn(
     // TODO: improve
     bool is_sliding = false;
 
-    switch (model.arch) {
+    switch (arch) {
         case LLM_ARCH_COHERE2:
             {
                 const int32_t sliding_window_pattern = 4;
@@ -2110,140 +2104,9 @@ void llm_graph_context::build_kv_self_shift(ggml_cgraph * gf) const {
 }
 
 void llm_graph_context::build_kv_self_defrag(ggml_cgraph * gf) const {
-    const llama_kv_cache_unified * kv_self_const = static_cast<const llama_kv_cache_unified *>(memory);
-
-    // TODO: avoid this
-    llama_kv_cache_unified * kv_self = const_cast<llama_kv_cache_unified *>(kv_self_const);
-
-    const uint32_t n_layer = hparams.n_layer;
-
-    const uint32_t n_kv   = kv_self->cell_max();
-    const uint32_t n_used = kv_self->used;
-
-    assert(n_used <= n_kv);
-
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //  - source view, destination view, copy operation
-    //  - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (graph_max_nodes() - 2*n_layer)/(6*n_layer);
-
-    // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    std::vector<uint32_t> ids(n_kv, n_kv);
-
-    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        const auto & cell0 = kv_self->cells[i0];
-
-        if (!cell0.is_empty()) {
-            ids[i0] = i0;
-
-            continue;
-        }
-
-        // found a hole - fill it with data from the end of the cache
-
-        uint32_t nh = 1;
-
-        // determine the size of the hole
-        while (i0 + nh < n_used && kv_self->cells[i0 + nh].is_empty()) {
-            nh++;
-        }
-
-        uint32_t nf = 0;
-        uint32_t is = n_kv - 1;
-
-        // starting from the end, find nh non-empty cells
-        for (; is > i0; --is) {
-            const auto & cell1 = kv_self->cells[is];
-
-            if (cell1.is_empty() || ids[is] != n_kv) {
-                continue;
-            }
-
-            // non-empty cell which is not yet moved
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        // this can only happen if `n_used` is not accurate, which would be a bug
-        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
-
-        nf = 0;
-
-        uint32_t i1 = is;
-
-        // are we moving a continuous block of memory?
-        bool cont = false;
-
-        // should we stop searching for the next move?
-        bool stop = false;
-
-        // go back and move the nf cells to the hole
-        for (; i1 < n_kv; ++i1) {
-            auto & cell1 = kv_self->cells[i1];
-
-            if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
-                cont = false;
-                continue;
-            }
-
-            // this cell goes to (i0 + nf)
-            ids[i1] = i0 + nf;
-
-            // move the cell meta data
-            kv_self->cells[i0 + nf] = cell1;
-
-            // clear the old cell and move the head there
-            cell1 = llama_kv_cell();
-            kv_self->head = n_used;
-
-            if (!cont) {
-                n_moves++;
-                cont = true;
-            }
-
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
-        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
-
-        i0 += nh - 1;
-    }
-
-    if (n_moves == 0) {
-        return;
-    }
-
-    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
-    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
+    const auto & ids = kv_self->defrag_info.ids;
 
 #if 0
     // CPU defrag
@@ -2424,8 +2287,8 @@ void llm_graph_context::build_pooling(ggml_cgraph * gf) const {
 
         // classification head
         // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-        GGML_ASSERT(model.cls != nullptr);
-        GGML_ASSERT(model.cls_b != nullptr);
+        GGML_ASSERT(model.cls   != nullptr);
+        GGML_ASSERT(model.cls_b != nullptr);
 
         cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
         cur = ggml_tanh(ctx0, cur);
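
The graph-building side now only consumes the precomputed plan. Below is a rough, self-contained illustration of how an ids plan can be collapsed into contiguous block moves; each block then costs a fixed number of view/copy nodes per layer for K and V, which is why the plan is capped against the maximum node count. The kv_move struct and group_moves() are hypothetical helper names for illustration only; the actual graph code builds ggml views and copy ops per layer and is not reproduced in this hunk.

// group_moves_sketch.cpp -- illustrative only, not llama.cpp code.
// Given a defrag plan `ids` (cell i moves to ids[i]; ids[i] == i or
// ids[i] == n_kv means "no move"), collapse consecutive cells that move
// together into block moves.
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_move {
    uint32_t src; // first source cell
    uint32_t dst; // first destination cell
    uint32_t len; // number of consecutive cells
};

static std::vector<kv_move> group_moves(const std::vector<uint32_t> & ids) {
    const uint32_t n_kv = (uint32_t) ids.size();

    std::vector<kv_move> moves;

    for (uint32_t i = 0; i < n_kv; ++i) {
        const uint32_t id = ids[i];

        if (id == i || id == n_kv) {
            continue; // cell stays in place or is unused
        }

        // extend the previous block if this cell continues it
        if (!moves.empty() &&
            moves.back().src + moves.back().len == i &&
            moves.back().dst + moves.back().len == id) {
            moves.back().len++;
        } else {
            moves.push_back({ i, id, 1 });
        }
    }

    return moves;
}

int main() {
    // e.g. the plan produced by the planning sketch above: cells 6 and 7 move to 2 and 1
    const std::vector<uint32_t> ids = {0, 8, 8, 3, 4, 8, 2, 1};

    for (const auto & m : group_moves(ids)) {
        std::printf("copy %u cell(s): [%u, %u) -> [%u, %u)\n", m.len, m.src, m.src + m.len, m.dst, m.dst + m.len);
    }
    return 0;
}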

src/llama-graph.h

Lines changed: 11 additions & 6 deletions
@@ -351,8 +351,12 @@ class llm_graph_result : public llm_graph_result_i {
     std::vector<llm_graph_input_ptr> inputs;
 };
 
+//
+// llm_graph_context
+//
+
 struct llm_graph_params {
-    ggml_context * ctx;
+    ggml_context * ctx;
 
     const llama_model & model;
     const llama_cparams & cparams;
@@ -371,7 +375,10 @@ struct llm_graph_params {
 };
 
 struct llm_graph_context {
-    const llama_model & model; // TODO: remove reference to model
+    const llama_model & model; // TODO: remove reference to model
+
+    const llm_arch arch;
+
     const llama_hparams & hparams;
     const llama_cparams & cparams;
     const llama_ubatch & ubatch;
@@ -407,8 +414,9 @@ struct llm_graph_context {
 
     ggml_context * ctx0 = nullptr;
 
-    // TODO: these are only used by the cb() call, so maybe we can avoid them in the future
     ggml_backend_sched * sched;
+
+    // TODO: these are only used by the cb() call, so maybe we can avoid them in the future
     ggml_backend * backend_cpu;
     const std::vector<ggml_backend_ptr> & backends;
 
@@ -421,9 +429,6 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    // TODO: deduplicate with llama_context::graph_max_nodes()
-    int32_t graph_max_nodes() const;
-
     int64_t n_pos_per_token() const;
 
     // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
