@@ -37,8 +37,15 @@ llama_kv_cache::llama_kv_cache(
 
     const uint32_t n_layer_kv = hparams.n_layer_kv();
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +60,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,19 +173,16 @@ llama_kv_cache::llama_kv_cache(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }
 
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
 
         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -203,7 +206,7 @@ void llama_kv_cache::clear(bool data) {
     }
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +475,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1301,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
 
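
For readers unfamiliar with the pattern the hunks above introduce, here is a standalone sketch of the same idea: a std::map keyed by ggml_backend_buffer_type_t with a name-based comparator so that iteration order is deterministic, ggml_context_ptr values so each context is freed automatically, and a structured-bindings loop that allocates one backend buffer per buffer type. It is an illustration only, not code from the commit: it assumes the public ggml headers (ggml.h, ggml-alloc.h, ggml-backend.h, ggml-cpp.h), uses only the CPU buffer type, and orders keys by comparing the name strings with strcmp; buft_name_comparator and the tensor size are made-up names and values for the example.

// Standalone sketch (not part of the commit above): buft-keyed context map in isolation.
#include <cstdio>
#include <cstring>
#include <map>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpp.h" // ggml_context_ptr: std::unique_ptr<ggml_context, ...> with a ggml_free deleter

struct buft_name_comparator {
    bool operator()(ggml_backend_buffer_type_t lhs, ggml_backend_buffer_type_t rhs) const {
        // compare the human-readable names so that map iteration order does not
        // depend on where the buffer type objects happen to live in memory
        return std::strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
    }
};

int main() {
    // one ggml context per buffer type; the context owns the tensor metadata,
    // the backend buffer owns the actual tensor data
    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, buft_name_comparator> ctx_map;

    ggml_init_params params = {
        /*.mem_size   =*/ 8*ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor data is allocated from a backend buffer below
    };
    ctx_map.emplace(ggml_backend_cpu_buffer_type(), ggml_context_ptr(ggml_init(params)));

    // name-ordered iteration: create a tensor in each context and allocate
    // one backend buffer per buffer type
    for (auto & [buft, ctx] : ctx_map) {
        ggml_new_tensor_1d(ctx.get(), GGML_TYPE_F32, 16);

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
        if (!buf) {
            return 1;
        }
        ggml_backend_buffer_clear(buf, 0);
        std::printf("%s: %zu bytes\n", ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf));
        ggml_backend_buffer_free(buf);
    }

    return 0;
}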