From 98c7edd81817aaebcd217427c6b0f95209cada71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 23 Oct 2025 20:49:57 +0200 Subject: [PATCH] llama: consistent ctx <-> buf order for KV cache --- src/llama-kv-cache.cpp | 32 ++++++++++++++++++-------------- src/llama-kv-cache.h | 4 ++-- src/llama-memory-recurrent.cpp | 32 ++++++++++++++++++-------------- src/llama-memory-recurrent.h | 4 ++-- src/llama-model.cpp | 2 +- 5 files changed, 41 insertions(+), 33 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 736693e174527..add74391f0c47 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache( const uint32_t n_layer_kv = hparams.n_layer_kv(); + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + // create a context for each buffer type - std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache( return nullptr; } - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max); @@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache( } // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for kv cache"); } @@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache( LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); ggml_backend_buffer_clear(buf, 0); - bufs.emplace_back(buf); + ctxs_bufs.emplace_back(std::move(ctx), buf); } { @@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) { } if (data) { - for (auto & buf : bufs) { + for (auto & [_, buf] : ctxs_bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } @@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { std::map llama_kv_cache::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -1298,7 +1302,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch size_t llama_kv_cache::total_size() const { size_t size = 0; - for (const auto & buf : bufs) { + for (const auto & [_, buf] : ctxs_bufs) { size += ggml_backend_buffer_get_size(buf.get()); } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 85f0663d8c1d4..150e282596255 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -217,8 +217,8 @@ class llama_kv_cache : public llama_memory_i { // this is the SWA type of the cache - not to be confused with the model SWA type const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - std::vector ctxs; - std::vector bufs; + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector> ctxs_bufs; // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) // note: this is not part of the KV state and it's only used to speed-up the find_slot() method diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index d67f5a5f47b87..276e1697d466c 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent( cells.clear(); cells.resize(mem_size); + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + // create a context for each buffer type - std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent( return nullptr; } - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; r_l.resize(n_layer); @@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent( } // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for rs cache"); } ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - bufs.emplace_back(buf); + ctxs_bufs.emplace_back(std::move(ctx), buf); } { @@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) { used = 0; if (data) { - for (auto & buf : bufs) { + for (auto & [_, buf] : ctxs_bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } @@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const { size_t llama_memory_recurrent::total_size() const { size_t size = 0; - for (const auto & buf : bufs) { + for (const auto & [_, buf] : ctxs_bufs) { size += ggml_backend_buffer_get_size(buf.get()); } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 077c6e3ce938d..47f01d7391248 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -109,8 +109,8 @@ class llama_memory_recurrent : public llama_memory_i { const uint32_t n_seq_max = 1; - std::vector ctxs; - std::vector bufs; + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector> ctxs_bufs; size_t total_size() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 05e467180089e..bb83a04e96055 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2231,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // define a comparator for the buft -> ctx map to ensure that the order is well-defined: struct ggml_backend_buft_comparator { bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { - return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs); + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; } }; std::map ctx_map;