@@ -38,7 +38,7 @@ llama_kv_cache::llama_kv_cache(
     const uint32_t n_layer_kv = hparams.n_layer_kv();

     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    buft_ctx_map_t ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +53,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }

-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);

             return ctx;
         }

-        return it->second;
+        return it->second.get();
     };

     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,19 +166,16 @@ llama_kv_cache::llama_kv_cache(
     }

     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }

         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);

         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }

     {
@@ -203,7 +199,7 @@ void llama_kv_cache::clear(bool data) {
     }

     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +468,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {

 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1294,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;

-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }

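For context: this change folds the previously separate ctxs and bufs members into a single ctxs_bufs container, so each ggml context stays paired with the backend buffer allocated from it, and ctx_map now owns its contexts (hence the .get() and std::move(ctx) calls in the hunks above). A minimal sketch of the declarations the diff appears to assume follows; the alias and member definitions below are inferred from the usage in the diff, not taken from the actual header:

    // assumed: owning map built while constructing the cache; ggml_context_ptr
    // and ggml_backend_buffer_ptr are the unique_ptr aliases from ggml-cpp.h
    using buft_ctx_map_t = std::map<ggml_backend_buffer_type_t, ggml_context_ptr>;

    // assumed: each context is stored next to the buffer allocated from it,
    // so both are released together when the cache is destroyed
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;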