
Commit 41deb58

llama: consistent ctx <-> buf order for KV cache
1 parent dd62dcf commit 41deb58

File tree

4 files changed, +38 -32 lines changed

src/llama-kv-cache.cpp

Lines changed: 17 additions & 14 deletions
@@ -37,8 +37,15 @@ llama_kv_cache::llama_kv_cache(
 
     const uint32_t n_layer_kv = hparams.n_layer_kv();
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +60,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,19 +173,16 @@ llama_kv_cache::llama_kv_cache(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }
 
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
 
         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -203,7 +206,7 @@ void llama_kv_cache::clear(bool data) {
     }
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +475,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1301,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
 
src/llama-kv-cache.h

Lines changed: 2 additions & 2 deletions
@@ -217,8 +217,8 @@ class llama_kv_cache : public llama_memory_i {
     // this is the SWA type of the cache - not to be confused with the model SWA type
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
src/llama-memory-recurrent.cpp

Lines changed: 17 additions & 14 deletions
@@ -32,8 +32,15 @@ llama_memory_recurrent::llama_memory_recurrent(
     cells.clear();
     cells.resize(mem_size);
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -48,13 +55,12 @@ llama_memory_recurrent::llama_memory_recurrent(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     r_l.resize(n_layer);
@@ -93,17 +99,14 @@ llama_memory_recurrent::llama_memory_recurrent(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
            throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -129,7 +132,7 @@ void llama_memory_recurrent::clear(bool data) {
     used = 0;
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -364,8 +367,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -662,7 +665,7 @@ bool llama_memory_recurrent::get_can_shift() const {
 
 size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
 
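One detail shared by both allocation loops above is ctxs_bufs.emplace_back(std::move(ctx), buf): the context smart pointer is moved out of its map entry while the map is still being iterated. A small sketch with plain standard-library types (the names are illustrative only, not ggml API) of why this is fine: moving the mapped value does not touch the map's structure, it just leaves a null pointer behind in an entry that is never read again and that goes out of scope once construction finishes.

#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

int main() {
    // name -> owned object, analogous to buft -> ggml_context_ptr in ctx_map
    std::map<std::string, std::unique_ptr<int>> ctx_map;
    ctx_map.emplace("CPU",   std::make_unique<int>(1));
    ctx_map.emplace("CUDA0", std::make_unique<int>(2));

    std::vector<std::pair<std::string, std::unique_ptr<int>>> ctxs_bufs;

    // Iterate by reference and move ownership out of each map entry.
    // Only the mapped value is moved; keys and tree structure are untouched,
    // so the iteration itself stays valid.
    for (auto & [name, ctx] : ctx_map) {
        ctxs_bufs.emplace_back(name, std::move(ctx));
    }

    assert(ctx_map.at("CPU") == nullptr);  // moved-from, but still a valid (null) state
    assert(*ctxs_bufs[0].second == 1);     // ownership now lives in ctxs_bufs
    return 0;
}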
src/llama-memory-recurrent.h

Lines changed: 2 additions & 2 deletions
@@ -109,8 +109,8 @@ class llama_memory_recurrent : public llama_memory_i {
 
     const uint32_t n_seq_max = 1;
 
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     size_t total_size() const;
 