Commit d3aaab5

jessegross authored and Nexesenex committed
llama: Ensure KV cache is fully defragmented.
Sometimes the KV cache requires defragmentation even without triggering the threshold heuristic. In this case, decoding will not be able to find a KV cache slot. This is particularly difficult for the caller to handle if it happens in between ubatches. To avoid this, we should immediately trigger a defrag.

In addition, a heavily fragmented cache can require more than max_moves to defragment. Currently, we stop when we hit the limit, but this can leave a cache that still does not have adequate space even after defragmentation is triggered. Instead, we should do multiple batches of processing until everything is complete.
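In isolation, the batching idea looks roughly like the sketch below. This is a minimal, self-contained illustration, not the llama.cpp implementation: apart from llama_kv_defrag_move it uses stand-in names and values, and the per-chunk graph build/compute from the real code is replaced by a printout of the copies that would be scheduled.

// Sketch: coalesced KV-cache moves applied in chunks of at most max_moves,
// instead of stopping once the limit is hit. Simplified stand-in for the
// real code, which resets the scheduler and builds/computes a ggml graph per chunk.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// block of KV slots to move when defragging (same shape as the struct added in this commit)
struct llama_kv_defrag_move {
    uint32_t src;
    uint32_t dst;
    uint32_t len;
};

int main() {
    // pretend the defrag scan coalesced the fragmented cells into these moves
    std::vector<llama_kv_defrag_move> moves = {
        {128, 0, 64}, {300, 64, 32}, {512, 96, 200}, {900, 296, 10},
    };

    const std::size_t max_moves = 2; // artificially small to show the chunking

    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
        const std::size_t end = std::min(i + max_moves, moves.size());
        std::vector<llama_kv_defrag_move> chunk(moves.begin() + i, moves.begin() + end);

        // real code: reset the scheduler, build a defrag graph for this chunk, compute it
        for (const auto & m : chunk) {
            std::printf("copy %u cells: [%u, %u) -> [%u, %u)\n",
                        (unsigned) m.len, (unsigned) m.src, (unsigned) (m.src + m.len),
                        (unsigned) m.dst, (unsigned) (m.dst + m.len));
        }
    }
    return 0;
}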
1 parent 06cc123 commit d3aaab5

File tree

1 file changed (+46, -53 lines)

src/llama.cpp

Lines changed: 46 additions & 53 deletions
@@ -3070,6 +3070,13 @@ struct llama_kv_cache {
     }
 };
 
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+    uint32_t src;
+    uint32_t dst;
+    uint32_t len;
+};
+
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<ggml_context_ptr> ctxs;
@@ -11239,67 +11246,53 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
+    struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        for (uint32_t i = 0; i < ids.size(); ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == ids.size()) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
+        for (const auto & move : moves) {
             for (int il = 0; il < n_layer; ++il) {
                 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
                 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
+                        n_embd_k_gqa, move.len,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
 
                 ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
+                        n_embd_k_gqa, move.len,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
 
                 ggml_tensor * view_v_src;
                 ggml_tensor * view_v_dst;
 
                 if (flash_attn) {
                     // NOTE: the V cache is not transposed when using flash attention
                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
+                            n_embd_v_gqa, move.len,
                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
 
                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
+                            n_embd_v_gqa, move.len,
                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
                 } else {
                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
+                            move.len, n_embd_v_gqa,
                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, i));
+                            ggml_row_size(kv_self.v_l[il]->type, move.src));
 
                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
+                            move.len, n_embd_v_gqa,
                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, id));
+                            ggml_row_size(kv_self.v_l[il]->type, move.dst));
                 }
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
             }
-
-            i += nm - 1;
         }
 
         //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -17855,7 +17848,7 @@ struct llm_build_context {
     }
 };
 
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
    llama_ubatch dummy = {};
    dummy.equal_seqs = true;
 
@@ -17865,7 +17858,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
 
    llm.init();
 
-    struct ggml_cgraph * result = llm.build_defrag(ids);
+    struct ggml_cgraph * result = llm.build_defrag(moves);
 
    llm.free();
 
@@ -18881,7 +18874,12 @@ static int llama_decode_internal(
            kv_self.head = 0;
        }
 
-        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            llama_kv_cache_defrag(kv_self);
+            llama_kv_cache_update(&lctx);
+            slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        }
        if (!slot) {
            return 1;
        }
@@ -19284,8 +19282,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
    //const int64_t t_start = ggml_time_us();
 
-    // number of cells moved
-    uint32_t n_moves = 0;
+    // groups of cells moved
+    std::vector<struct llama_kv_defrag_move> moves;
 
    // each move requires 6*n_layer tensors (see build_defrag)
    //   - source view, destination view, copy operation
@@ -19349,19 +19347,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
        // are we moving a continuous block of memory?
        bool cont = false;
 
-        // should we stop searching for the next move?
-        bool stop = false;
-
        // go back and move the nf cells to the hole
        for (; i1 < n_kv; ++i1) {
            auto & cell1 = kv_self.cells[i1];
 
            if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
                cont = false;
                continue;
            }
@@ -19377,8 +19367,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
            kv_self.head = n_used;
 
            if (!cont) {
-                n_moves++;
+                moves.push_back({i1, i0 + nf, 1});
                cont = true;
+            } else {
+                moves.back().len++;
            }
 
            nf++;
@@ -19388,22 +19380,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
            }
        }
 
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
 
        i0 += nh - 1;
    }
 
-    if (n_moves == 0) {
+    if (moves.size() == 0) {
        return;
    }
 
-    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
-
-    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
+    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
 
 #if 0
    // CPU defrag
@@ -19478,11 +19464,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 #else
    // ggml_graph defrag
 
-    ggml_backend_sched_reset(lctx.sched.get());
+    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
+        std::vector<struct llama_kv_defrag_move> chunk;
+        auto end = std::min(i + max_moves, moves.size());
+        chunk.assign(moves.begin() + i, moves.begin() + end);
 
-    ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
+        ggml_backend_sched_reset(lctx.sched.get());
+
+        //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
+        ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
 
-    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+    }
 #endif
 
    //const int64_t t_end = ggml_time_us();
