@@ -177,65 +177,35 @@ llama_context::llama_context(
     }
 
     // init the memory module
-    // TODO: for now, always create a unified KV cache
     if (!hparams.vocab_only) {
-        uint32_t kv_size = 0;
-        ggml_type type_k = params.type_k;
-        ggml_type type_v = params.type_v;
+        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
 
         if (!llama_model_is_recurrent(&model)) {
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
             cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));
 
             LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-            kv_size = cparams.n_ctx;
-            type_k = params.type_k;
-            type_v = params.type_v;
-
             llama_memory_params params_mem = {
-                /*.type_k      =*/ type_k,
-                /*.type_v      =*/ type_v,
+                /*.type_k      =*/ params.type_k,
+                /*.type_v      =*/ params.type_v,
                 /*.v_trans     =*/ !cparams.flash_attn,
                 /*.offload_kqv =*/ cparams.offload_kqv,
-                /*.kv_size     =*/ kv_size,
+                /*.kv_size     =*/ cparams.n_ctx,
             };
 
-            auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory(params_mem));
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
         } else {
-            // Mamba needs at least as many KV cells as there are sequences kept at any time
-            kv_size = std::max((uint32_t) 1, params.n_seq_max);
-            // it's probably best to keep as much precision as possible for the states
-            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-
             llama_memory_params params_mem = {
-                /*.type_k      =*/ type_k,
-                /*.type_v      =*/ type_v,
+                /*.type_k      =*/ GGML_TYPE_F32, // required by ggml_ssm_conv for Mamba's conv_states
+                /*.type_v      =*/ GGML_TYPE_F32, // required by ggml_ssm_scan for Mamba's ssm_states
                 /*.v_trans     =*/ false, // unused
-                /*.offload_kqv =*/ params.offload_kqv,
-                /*.kv_size     =*/ kv_size,
+                /*.offload_kqv =*/ cparams.offload_kqv,
+                /*.kv_size     =*/ std::max((uint32_t) 1, params.n_seq_max), // Mamba needs at least as many KV cells as there are sequences kept at any time
             };
 
-            auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory(params_mem));
-
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
         }
 
-        {
-            const size_t memory_size_k = kv_self->size_k_bytes();
-            const size_t memory_size_v = kv_self->size_v_bytes();
-
-            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
     }
 
     // init backends
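
The two branches above fill the same llama_memory_params aggregate and hand it to model.create_memory(), which now decides between a unified and a recurrent cache. Below is a minimal sketch of a parameter struct consistent with those initializers; the field names come from the diff, while the exact types and the name llama_memory_params_sketch are assumptions for illustration only.

    #include "ggml.h"   // for ggml_type; assumes the ggml headers are on the include path
    #include <cstdint>

    // sketch only: shape inferred from the comment-annotated initializers in the hunk above
    struct llama_memory_params_sketch {
        ggml_type type_k;      // K cache type (forced to F32 for Mamba's conv states)
        ggml_type type_v;      // V cache type (forced to F32 for Mamba's ssm states)
        bool      v_trans;     // store V transposed (ignored by the recurrent cache)
        bool      offload_kqv; // keep the cache on the backend device
        uint32_t  kv_size;     // padded n_ctx (unified) or max(1, n_seq_max) (recurrent)
    };
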
@@ -326,6 +296,8 @@ llama_context::llama_context(
         int n_nodes_tg = -1;
 
         // simulate full KV cache
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
         kv_self->set_full();
 
         cross.v_embd.clear();
@@ -477,11 +449,13 @@ uint32_t llama_context::n_threads_batch() const {
 }
 
 llama_kv_cache * llama_context::get_kv_self() {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }
 
 const llama_kv_cache * llama_context::get_kv_self() const {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }
 
 ggml_tensor * llama_context::build_rope_shift(
@@ -578,7 +552,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
 
     // GGML_ASSERT(kv_self->size == n_ctx);
 
-    const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+    const auto * kv = static_cast<const llama_kv_cache_unified *>(memory.get());
 
     auto inp = std::make_unique<llm_graph_input_k_shift>(kv);
 
@@ -620,7 +594,7 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+    auto * kv = static_cast<llama_kv_cache_unified *>(memory.get());
 
     const auto & hparams = model.hparams;
 
@@ -766,6 +740,8 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
 void llama_context::kv_self_update() {
     bool need_reserve = false;
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     if (kv_self->get_has_shift()) {
         if (!kv_self->get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -791,7 +767,7 @@ void llama_context::kv_self_update() {
         }
 
         {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
 
             kv->has_shift = false;
 
@@ -805,7 +781,7 @@ void llama_context::kv_self_update() {
     if (kv_self->get_do_defrag()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
 
         if (kv->defrag_prepare(graph_max_nodes())) {
             ggml_backend_sched_reset(sched.get());
@@ -1054,6 +1030,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         return -1;
     }
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1219,6 +1197,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1233,7 +1213,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
 
-    llama_kv_cache_guard kv_guard(kv_self.get());
+    llama_kv_cache_guard kv_guard(kv_self);
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
@@ -1337,7 +1317,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         if (!is_recurrent) {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
 
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
@@ -1489,7 +1469,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     // decide if we need to defrag the kv cache
     if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
 
         // - do not defrag small contexts (i.e. < 2048 tokens)
         // - count the padding towards the number of used tokens
@@ -1662,7 +1642,7 @@ llm_graph_result_ptr llama_context::graph_build(
                 /*.backend_cpu =*/ backend_cpu,
                 /*.cvec        =*/ &cvec,
                 /*.loras       =*/ &loras,
-                /*.memory      =*/ kv_self.get(),
+                /*.memory      =*/ memory.get(),
                 /*.cross       =*/ &cross,
                 /*.n_outputs   =*/ n_outputs,
                 /*.cb          =*/ graph_get_cb(),
@@ -2121,6 +2101,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     }
 
     LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io);
 
     return io.n_bytes();
@@ -2205,6 +2187,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     }
 
     LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_read(io);
 
     return io.n_bytes();
@@ -2213,6 +2197,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io, seq_id);
 
     return io.n_bytes();
@@ -2221,6 +2207,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_read(io, seq_id);
 
     return io.n_bytes();
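
The remaining hunks all follow one pattern: the context keeps a generic memory handle and downcasts it wherever KV-cache-specific behaviour is needed. Below is a minimal sketch of that pattern, assuming memory is a std::unique_ptr to a llama_memory_i base class and that llama_kv_cache derives from it; both names appear in the diff, but the exact hierarchy is an assumption here.

    #include <memory>

    // sketch only: stand-ins for the real llama.cpp types
    struct llama_memory_i { virtual ~llama_memory_i() = default; };
    struct llama_kv_cache : llama_memory_i { /* KV-specific API (set_full, state_write, ...) */ };

    struct context_sketch {
        std::unique_ptr<llama_memory_i> memory; // owns the cache, unified or recurrent

        llama_kv_cache * get_kv_self() {
            // safe only because the constructor always stores a KV-cache-derived object
            return static_cast<llama_kv_cache *>(memory.get());
        }
    };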