
Commit a2d4b6f

llama: Ensure KV cache is fully defragmented.
Sometimes the KV cache requires defragmentation even without triggering the threshold heuristic. In this case, decoding will not be able to find a KV cache slot. This is particularly difficult for the caller to handle if it happens in between ubatches, so we should immediately trigger a defrag when slot allocation fails.

In addition, a heavily fragmented cache can require more than max_moves to defragment. Currently we stop when we hit the limit, which can leave the cache without adequate space even after defragmentation has been triggered. Instead, we should process the moves in multiple batches until everything is complete.
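In short, the change does two things; the sketch below condenses them from the diff that follows (illustration of the control flow only, not a drop-in excerpt):

// (1) llama_decode_internal: if no KV slot is free, trigger a defrag and retry once.
auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
    llama_kv_cache_defrag(kv_self);   // mark the cache for defragmentation
    llama_kv_cache_update(&lctx);     // runs the defrag before we retry
    slot = llama_kv_cache_find_slot(kv_self, ubatch);
}
if (!slot) {
    return 1;                         // still no room for this ubatch
}

// (2) llama_kv_cache_defrag_internal: collect all pending moves first, then apply
//     them in chunks of at most max_moves so no single graph exceeds the node limit.
for (std::size_t i = 0; i < moves.size(); i += max_moves) {
    std::vector<llama_kv_defrag_move> chunk(
        moves.begin() + i,
        moves.begin() + std::min(i + max_moves, moves.size()));

    ggml_backend_sched_reset(lctx.sched.get());
    ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
}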
1 parent 081b29b · commit a2d4b6f

File tree

1 file changed: src/llama.cpp (46 additions & 53 deletions)
@@ -2955,6 +2955,13 @@ struct llama_kv_cache {
     }
 };
 
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+    uint32_t src;
+    uint32_t dst;
+    uint32_t len;
+};
+
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<ggml_context_ptr> ctxs;
@@ -10652,67 +10659,53 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
+    struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        for (uint32_t i = 0; i < ids.size(); ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == ids.size()) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
+        for (const auto & move : moves) {
             for (int il = 0; il < n_layer; ++il) {
                 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
                 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
+                        n_embd_k_gqa, move.len,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
 
                 ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
+                        n_embd_k_gqa, move.len,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
 
                 ggml_tensor * view_v_src;
                 ggml_tensor * view_v_dst;
 
                 if (flash_attn) {
                     // NOTE: the V cache is not transposed when using flash attention
                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
+                            n_embd_v_gqa, move.len,
                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
 
                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
+                            n_embd_v_gqa, move.len,
                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
                 } else {
                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
+                            move.len, n_embd_v_gqa,
                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, i));
+                            ggml_row_size(kv_self.v_l[il]->type, move.src));
 
                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
+                            move.len, n_embd_v_gqa,
                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, id));
+                            ggml_row_size(kv_self.v_l[il]->type, move.dst));
                 }
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
             }
-
-            i += nm - 1;
         }
 
         //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -16944,7 +16937,7 @@ struct llm_build_context {
     }
 };
 
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
     llama_ubatch dummy = {};
     dummy.equal_seqs = true;
 
@@ -16954,7 +16947,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
 
     llm.init();
 
-    struct ggml_cgraph * result = llm.build_defrag(ids);
+    struct ggml_cgraph * result = llm.build_defrag(moves);
 
     llm.free();
 
@@ -17957,7 +17950,12 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            if (!slot) {
+                llama_kv_cache_defrag(kv_self);
+                llama_kv_cache_update(&lctx);
+                slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            }
             if (!slot) {
                 return 1;
             }
@@ -18359,8 +18357,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 
     //const int64_t t_start = ggml_time_us();
 
-    // number of cells moved
-    uint32_t n_moves = 0;
+    // groups of cells moved
+    std::vector<struct llama_kv_defrag_move> moves;
 
     // each move requires 6*n_layer tensors (see build_defrag)
     // - source view, destination view, copy operation
@@ -18424,19 +18422,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
         // are we moving a continuous block of memory?
         bool cont = false;
 
-        // should we stop searching for the next move?
-        bool stop = false;
-
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
             auto & cell1 = kv_self.cells[i1];
 
             if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
                 cont = false;
                 continue;
             }
@@ -18452,8 +18442,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
             kv_self.head = n_used;
 
             if (!cont) {
-                n_moves++;
+                moves.push_back({i1, i0 + nf, 1});
                 cont = true;
+            } else {
+                moves.back().len++;
             }
 
             nf++;
@@ -18463,22 +18455,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
             }
         }
 
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
         //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
 
         i0 += nh - 1;
     }
 
-    if (n_moves == 0) {
+    if (moves.size() == 0) {
         return;
     }
 
-    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
-
-    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
+    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
 
 #if 0
     // CPU defrag
@@ -18553,11 +18539,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 #else
     // ggml_graph defrag
 
-    ggml_backend_sched_reset(lctx.sched.get());
+    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
+        std::vector<struct llama_kv_defrag_move> chunk;
+        auto end = std::min(i + max_moves, moves.size());
+        chunk.assign(moves.begin() + i, moves.begin() + end);
 
-    ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
+        ggml_backend_sched_reset(lctx.sched.get());
+
+        //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
+        ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
 
-    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+    }
 #endif
 
     //const int64_t t_end = ggml_time_us();
