Commit f355522

llama: consistent ctx <-> buf order for KV cache

1 parent dd62dcf commit f355522

6 files changed: 35 additions, 39 deletions

ggml/include/ggml-cpp.h
Lines changed: 10 additions & 0 deletions

@@ -8,6 +8,8 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "gguf.h"
+
+#include <map>
 #include <memory>
 
 // Smart pointers for ggml types
@@ -37,3 +39,11 @@ typedef std::unique_ptr<ggml_backend, ggml_backend_deleter> ggml_backend_ptr;
 typedef std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter> ggml_backend_buffer_ptr;
 typedef std::unique_ptr<ggml_backend_event, ggml_backend_event_deleter> ggml_backend_event_ptr;
 typedef std::unique_ptr<ggml_backend_sched, ggml_backend_sched_deleter> ggml_backend_sched_ptr;
+
+// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+struct ggml_backend_buft_comparator {
+    bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+        return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+    }
+};
+typedef std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> buft_ctx_map_t;
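The new buft_ctx_map_t keys ggml contexts by backend buffer type and lets the map own each context through ggml_context_ptr, so the buffer-type-to-context association no longer depends on two parallel containers staying in sync. Below is a minimal sketch of the lazy "one context per buffer type" lookup that the callers in this commit build on top of the typedef; the tensor count and the ggml_init parameters are illustrative placeholders, not values taken from llama.cpp:

#include "ggml-cpp.h"

// Sketch: return the ggml context associated with a buffer type, creating it on
// first use. The map keeps ownership of every context via ggml_context_ptr.
static ggml_context * ctx_for_buft(buft_ctx_map_t & ctx_map, ggml_backend_buffer_type_t buft, size_t n_tensors_max) {
    auto it = ctx_map.find(buft);
    if (it != ctx_map.end()) {
        return it->second.get();
    }

    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*n_tensors_max, // metadata only
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor data is allocated later into a backend buffer
    };

    ggml_context * ctx = ggml_init(params);
    if (!ctx) {
        return nullptr;
    }

    ctx_map.emplace(buft, ctx); // ggml_context_ptr takes ownership
    return ctx;
}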

src/llama-kv-cache.cpp
Lines changed: 10 additions & 14 deletions

@@ -38,7 +38,7 @@ llama_kv_cache::llama_kv_cache(
     const uint32_t n_layer_kv = hparams.n_layer_kv();
 
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    buft_ctx_map_t ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +53,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,19 +166,16 @@ llama_kv_cache::llama_kv_cache(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }
 
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
 
         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -203,7 +199,7 @@ void llama_kv_cache::clear(bool data) {
     }
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +468,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1294,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
 
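Because the map now owns the contexts, the allocation loop can walk it with structured bindings and move each context into ctxs_bufs right next to the buffer that was allocated from it, so element i of the vector always pairs a context with its own buffer. A simplified sketch of that pattern, with the logging and the cache-specific error message trimmed:

#include <stdexcept>
#include <utility>
#include <vector>

#include "ggml-cpp.h"

// Sketch: allocate one backend buffer per context and store them as pairs,
// keeping the ctx <-> buf association explicit instead of using two parallel vectors.
static std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> alloc_ctx_buffers(buft_ctx_map_t & ctx_map) {
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

    for (auto & [buft, ctx] : ctx_map) {
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
        if (!buf) {
            throw std::runtime_error("failed to allocate buffer");
        }

        ggml_backend_buffer_clear(buf, 0); // zero-initialize to avoid NaNs in the padding

        // move the context out of the map entry; the map is discarded after this loop
        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }

    return ctxs_bufs;
}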

src/llama-kv-cache.h
Lines changed: 2 additions & 2 deletions

@@ -217,8 +217,8 @@ class llama_kv_cache : public llama_memory_i {
     // this is the SWA type of the cache - not to be confused with the model SWA type
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method

src/llama-memory-recurrent.cpp
Lines changed: 10 additions & 14 deletions

@@ -33,7 +33,7 @@ llama_memory_recurrent::llama_memory_recurrent(
     cells.resize(mem_size);
 
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    buft_ctx_map_t ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -48,13 +48,12 @@ llama_memory_recurrent::llama_memory_recurrent(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     r_l.resize(n_layer);
@@ -93,17 +92,14 @@ llama_memory_recurrent::llama_memory_recurrent(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -129,7 +125,7 @@ void llama_memory_recurrent::clear(bool data) {
     used = 0;
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
            ggml_backend_buffer_clear(buf.get(), 0);
        }
    }
@@ -364,8 +360,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -662,7 +658,7 @@ bool llama_memory_recurrent::get_can_shift() const {
 
 size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
 
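Both memory_breakdown() implementations now aggregate sizes from the paired vector rather than the old bufs vector, but the result is still a plain map from buffer type to total bytes. As an illustration of how such a breakdown could be consumed, a hypothetical caller (not part of this commit) could label each entry with ggml_backend_buft_name():

#include <cstddef>
#include <cstdio>
#include <map>

#include "ggml-backend.h"

// Sketch: print a per-buffer-type memory breakdown, e.g. the map returned by
// llama_kv_cache::memory_breakdown() or llama_memory_recurrent::memory_breakdown().
static void print_memory_breakdown(const std::map<ggml_backend_buffer_type_t, size_t> & breakdown) {
    for (const auto & [buft, size] : breakdown) {
        printf("%10s: %8.2f MiB\n", ggml_backend_buft_name(buft), size/1024.0/1024.0);
    }
}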

src/llama-memory-recurrent.h
Lines changed: 2 additions & 2 deletions

@@ -109,8 +109,8 @@ class llama_memory_recurrent : public llama_memory_i {
 
     const uint32_t n_seq_max = 1;
 
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     size_t total_size() const;
 

src/llama-model.cpp
Lines changed: 1 addition & 7 deletions

@@ -2229,13 +2229,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     max_n_tensors += n_layer*2; // duplicated rope freq tensors
     const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
 
-    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
-    struct ggml_backend_buft_comparator {
-        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
-            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
-        }
-    };
-    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+    buft_ctx_map_t ctx_map;
 
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
