@@ -179,24 +179,37 @@ llama_context::llama_context(
     // init the memory module
     // TODO: for now, always create a unified KV cache
     if (!hparams.vocab_only) {
-        kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
+        uint32_t kv_size = 0;
+        ggml_type type_k = params.type_k;
+        ggml_type type_v = params.type_v;

-        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+        if (!llama_model_is_recurrent(&model)) {
+            // kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
+            auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory());

-        cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

-        LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+            cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams));

-        uint32_t kv_size = cparams.n_ctx;
-        ggml_type type_k = params.type_k;
-        ggml_type type_v = params.type_v;
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+            kv_size = cparams.n_ctx;
+            type_k = params.type_k;
+            type_v = params.type_v;
+
+            kv_self.reset(kv);
+        } else {
+            auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory());
+
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

-        if (llama_model_is_recurrent(&model)) {
             // Mamba needs at least as many KV cells as there are sequences kept at any time
             kv_size = std::max((uint32_t) 1, params.n_seq_max);
             // it's probably best to keep as much precision as possible for the states
             type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
             type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+
+            kv_self.reset(kv);
         }

         GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
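
The net effect of this hunk: the constructor now chooses the cache implementation, cell count and tensor precision per architecture before creating the cache object. A minimal compilable sketch of that decision, for illustration only (the kv_sizing struct and choose_kv_sizing() helper are placeholders, not part of the patch):

#include <algorithm>
#include <cstdint>

// placeholder result type; the real code assigns kv_size/type_k/type_v directly
struct kv_sizing {
    uint32_t kv_size;
    bool     f32_states; // the recurrent branch forces GGML_TYPE_F32 for K and V
};

static kv_sizing choose_kv_sizing(bool is_recurrent, uint32_t n_ctx_padded, uint32_t n_seq_max) {
    if (!is_recurrent) {
        // unified cache: one cell per (padded) context position, user-chosen K/V types
        return { n_ctx_padded, false };
    }
    // recurrent cache (e.g. Mamba): at least one cell per tracked sequence, full-precision states
    return { std::max<uint32_t>(1, n_seq_max), true };
}
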
@@ -305,7 +318,7 @@ llama_context::llama_context(
         int n_nodes_tg = -1;

         // simulate full KV cache
-        kv_self->n = kv_self->size;
+        kv_self->set_full();

         cross.v_embd.clear();

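
Previously the graph-reservation code reached directly into the cache (kv_self->n = kv_self->size) to simulate a worst-case full cache; set_full() hides that behind the cache interface. A hypothetical stand-in showing the behavior the method has to preserve (the real definition lives in the KV-cache implementation, which is not part of this diff):

#include <cstdint>

// hypothetical stand-in, for illustration only
struct kv_cache_sketch {
    uint32_t size = 0; // total number of allocated cells
    uint32_t n    = 0; // number of cells the next graph will attend to

    // equivalent to the removed direct assignment kv_self->n = kv_self->size
    void set_full() {
        n = size;
    }
};
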
@@ -557,7 +570,9 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(

     // GGML_ASSERT(kv_self->size == n_ctx);

-    auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
+    const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+
+    auto inp = std::make_unique<llm_graph_input_k_shift>(kv);

     inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
     ggml_set_input(inp->k_shift);
@@ -573,16 +588,16 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
         const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;

-        ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
+        ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il);

         ggml_tensor * k =
-            ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_head_kv, kv_self->size,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+            ggml_view_3d(ctx0, kv->k_l[il],
+                n_embd_head_k, n_head_kv, kv->size,
+                ggml_row_size(kv->k_l[il]->type, n_embd_head_k),
+                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
                 0);

-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv->k_l[il]->buffer);

         ggml_build_forward_expand(gf, cur);
     }
@@ -597,9 +612,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

+    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
     const auto & hparams = model.hparams;

-    const auto & ids = kv_self->defrag_info.ids;
+    const auto & ids = kv->defrag_info.ids;

 #if 0
     // CPU defrag
@@ -689,40 +706,40 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
         const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

-        ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+        ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il],
                 n_embd_k_gqa, nm,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
+                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
+                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i));

-        ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+        ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il],
                 n_embd_k_gqa, nm,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
+                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
+                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id));

         ggml_tensor * view_v_src;
         ggml_tensor * view_v_dst;

         if (cparams.flash_attn) {
             // NOTE: the V cache is not transposed when using flash attention
-            view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+            view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
                     n_embd_v_gqa, nm,
-                    ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
+                    ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
+                    ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i));

-            view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+            view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
                     n_embd_v_gqa, nm,
-                    ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
+                    ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
+                    ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id));
         } else {
-            view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+            view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
                     nm, n_embd_v_gqa,
-                    ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                    ggml_row_size(kv_self->v_l[il]->type, i));
+                    ggml_row_size(kv->v_l[il]->type, kv->size),
+                    ggml_row_size(kv->v_l[il]->type, i));

-            view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+            view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
                     nm, n_embd_v_gqa,
-                    ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                    ggml_row_size(kv_self->v_l[il]->type, id));
+                    ggml_row_size(kv->v_l[il]->type, kv->size),
+                    ggml_row_size(kv->v_l[il]->type, id));
         }

         ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
@@ -739,13 +756,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
 }

 void llama_context::kv_self_update() {
-    auto & kv = kv_self;
-
     bool need_reserve = false;

-    if (kv->has_shift) {
-        if (!kv->get_can_shift()) {
-            GGML_ABORT("The current context does not support K-shift");
+    if (kv_self->get_has_shift()) {
+        if (!kv_self->get_can_shift()) {
+            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }

         LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
@@ -768,6 +783,8 @@ void llama_context::kv_self_update() {
         }

         {
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
             kv->has_shift = false;

             for (uint32_t i = 0; i < kv->size; ++i) {
@@ -777,9 +794,11 @@ void llama_context::kv_self_update() {
     }

     // defragment the KV cache if needed
-    if (kv->do_defrag) {
+    if (kv_self->get_do_defrag()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
         if (kv->defrag_prepare(graph_max_nodes())) {
             ggml_backend_sched_reset(sched.get());

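
Taken together, the kv_self_update() hunks replace direct member reads (kv->has_shift, kv->do_defrag) with accessor calls on kv_self, and downcast to llama_kv_cache_unified only where unified-cache internals are actually touched. A rough sketch of the accessor surface implied by these call sites; the actual declarations live in the KV-cache header and may differ:

// assumed shape of the interface, inferred from the calls above
struct llama_kv_cache_iface {
    virtual ~llama_kv_cache_iface() = default;

    virtual bool get_has_shift() const = 0; // was: kv->has_shift
    virtual bool get_do_defrag() const = 0; // was: kv->do_defrag
    virtual bool get_can_shift() const = 0; // same call, now made through the base pointer
};

// call-site pattern: query through the base pointer, downcast only for the
// unified-cache-specific work (defrag bookkeeping, per-cell loops):
//
//     if (kv_self->get_do_defrag()) {
//         auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
//         ...
//     }
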
@@ -808,7 +827,7 @@ void llama_context::kv_self_update() {
         uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

         // simulate full KV cache
-        kv_self->n = kv_self->size;
+        kv_self->set_full();

         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -1028,8 +1047,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     }

     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

     const llama_batch & batch = batch_allocr.batch;
     const int32_t n_tokens = batch.n_tokens;
@@ -1193,8 +1212,8 @@ int llama_context::decode(llama_batch & inp_batch) {
     }

     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

     const llama_batch & batch = batch_allocr.batch;

@@ -1249,8 +1268,10 @@ int llama_context::decode(llama_batch & inp_batch) {

     const bool logits_all = n_outputs_all == n_tokens_all;

+    const bool is_recurrent = llama_model_is_recurrent(&model);
+
     sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self->recurrent,
+        /* simple_split */ !is_recurrent,
         /* logits_all   */ logits_all);

     // reserve output buffer
@@ -1269,7 +1290,7 @@ int llama_context::decode(llama_batch & inp_batch) {

         const auto & n_ubatch = cparams.n_ubatch;

-        if (kv_self->recurrent) {
+        if (is_recurrent) {
             if (embd_pooled) {
                 // Pooled embeddings cannot be split across ubatches (yet)
                 ubatch = sbatch.split_seq(cparams.n_ubatch);
@@ -1307,17 +1328,19 @@ int llama_context::decode(llama_batch & inp_batch) {
                 return 1;
             }

-            if (!kv_self->recurrent) {
+            if (!is_recurrent) {
+                auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self->get_padding(cparams);
-                kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
+                const uint32_t pad = kv->get_padding(cparams);
+                kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad)));
+
+                // printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head);
             }
         }

-        // printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

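
For the heuristic above, GGML_PAD(x, n) rounds x up to a multiple of n, so kv->n becomes the highest used cell count rounded up to the padding granularity, clamped to the cache size. A small worked example with made-up numbers (256-cell cache, padding 32, 70 cells in use):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// mirrors the GGML_PAD macro from ggml.h
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const uint32_t size     = 256; // kv->size                  (illustrative)
    const uint32_t pad      = 32;  // kv->get_padding(cparams)  (illustrative)
    const uint32_t cell_max = 70;  // kv->cell_max()            (illustrative)

    const uint32_t n = std::min(size, std::max(pad, (uint32_t) GGML_PAD(cell_max, pad)));

    printf("kv->n = %u\n", n); // 96 -> the graph attends to 96 cells instead of all 256
    return 0;
}
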
@@ -1457,10 +1480,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     // synchronize();

     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+    if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
         // - do not defrag small contexts (i.e. < 2048 tokens)
         // - count the padding towards the number of used tokens
-        const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
+        const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f;

         // queue defragmentation for next llama_kv_cache_update
         if (fragmentation > cparams.defrag_thold) {
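
With similarly made-up numbers, the fragmentation measure above compares the used cells (plus one padding block) against the attended window kv->n, and defrag is queued only when the gap exceeds defrag_thold:

#include <algorithm>
#include <cstdio>

int main() {
    // illustrative values, not taken from a real run
    const float n    = 4096.0f; // kv->n (>= 2048, so the check is active)
    const float used = 3000.0f; // kv->used
    const float pad  = 32.0f;   // kv->get_padding(cparams)

    const float fragmentation = std::max(0.0f, 1.0f - (used + pad)/n);

    printf("fragmentation = %.3f\n", fragmentation); // ~0.260
    // with defrag_thold = 0.1 this would queue a defrag for the next kv_self_update()
    return 0;
}
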