 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model.h"
+#include "llama-context.h"
 
 #include <algorithm>
 #include <cassert>
@@ -367,10 +368,10 @@ void llama_kv_cache_unified::commit() {
     pending.ranges.clear();
 }
 
-bool llama_kv_cache_unified::update(const graph_params & params) {
+bool llama_kv_cache_unified::update(llama_context & lctx) {
     bool need_reserve = false;
 
-    const auto & sched = params.sched;
+    const auto & sched = lctx.get_sched();
 
     if (has_shift) {
         if (!get_can_shift()) {
@@ -381,17 +382,17 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
 
         // apply K-shift if needed
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched);
+            ggml_backend_sched_reset(sched.get());
 
-            auto * gf = params.graph_init();
+            auto * gf = lctx.graph_init();
 
-            auto res = build_graph_shift(params, gf);
+            auto res = build_graph_shift(lctx, gf);
 
-            ggml_backend_sched_alloc_graph(sched, gf);
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
 
             res->set_inputs(nullptr);
 
-            params.graph_compute(gf);
+            lctx.graph_compute(gf, false);
 
             need_reserve = true;
         }
@@ -408,18 +409,18 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
     if (do_defrag) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        if (defrag_prepare(params.n_max_nodes)) {
-            ggml_backend_sched_reset(sched);
+        if (defrag_prepare(lctx.graph_max_nodes())) {
+            ggml_backend_sched_reset(sched.get());
 
-            auto * gf = params.graph_init();
+            auto * gf = lctx.graph_init();
 
-            auto res = build_graph_defrag(params, gf);
+            auto res = build_graph_defrag(lctx, gf);
 
-            ggml_backend_sched_alloc_graph(sched, gf);
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
 
             res->set_inputs(nullptr);
 
-            params.graph_compute(gf);
+            lctx.graph_compute(gf, false);
 
             need_reserve = true;
         }
@@ -591,17 +592,17 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
 }
 
 ggml_tensor * llama_kv_cache_unified::build_rope_shift(
-          const graph_params & params,
-                 ggml_context * ctx,
-                  ggml_tensor * cur,
-                  ggml_tensor * shift,
-                  ggml_tensor * factors,
-                        float   freq_base,
-                        float   freq_scale,
-         ggml_backend_buffer * bbuf) const {
-    const auto & cparams  = params.cparams;
-    const auto & backends = params.backends;
-    const auto & sched    = params.sched;
+        llama_context & lctx,
+         ggml_context * ctx,
+          ggml_tensor * cur,
+          ggml_tensor * shift,
+          ggml_tensor * factors,
+                float   freq_base,
+                float   freq_scale,
+  ggml_backend_buffer * bbuf) const {
+    const auto & cparams  = lctx.get_cparams();
+    const auto & backends = lctx.get_backends();
+    const auto & sched    = lctx.get_sched();
 
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
@@ -622,11 +623,12 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
         // dequantize to f32 -> RoPE -> quantize back
         tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
 
+        // TODO: can we simplify/avoid this?
         if (bbuf) {
             for (const auto & backend : backends) {
                 // Figure out which backend KV cache belongs to
                 if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
-                    ggml_backend_sched_set_tensor_backend(sched, tmp, backend.get());
+                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
                     break;
                 }
             }
@@ -674,13 +676,13 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
-             const graph_params & params,
-                    ggml_cgraph * gf) const {
+        llama_context & lctx,
+        ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    auto * ctx = params.get_ctx_compute();
+    auto * ctx = lctx.get_ctx_compute().get();
 
-    const auto & cparams = params.cparams;
+    const auto & cparams = lctx.get_cparams();
 
     const auto & n_layer = hparams.n_layer;
 
@@ -716,7 +718,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
                 ggml_row_size(k_l[il]->type, n_embd_k_gqa),
                 0);
 
-        ggml_tensor * cur = build_rope_shift(params, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer);
 
         ggml_build_forward_expand(gf, cur);
     }
@@ -727,15 +729,15 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-             const graph_params & params,
-                    ggml_cgraph * gf) const {
+        llama_context & lctx,
+          ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    auto * ctx = params.get_ctx_compute();
+    auto * ctx = lctx.get_ctx_compute().get();
 
     const auto & ids = defrag_info.ids;
 
-    const auto & cparams = params.cparams;
+    const auto & cparams = lctx.get_cparams();
 
 #if 0
     // CPU defrag
@@ -1725,8 +1727,8 @@ void llama_kv_cache_recurrent::commit() {
     pending.ranges.clear();
 }
 
-bool llama_kv_cache_recurrent::update(const graph_params & params) {
-    GGML_UNUSED(params);
+bool llama_kv_cache_recurrent::update(llama_context & lctx) {
+    GGML_UNUSED(lctx);
     return false;
 }
 
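For readers following the refactor, here is a minimal, self-contained sketch of the calling pattern this diff moves to: the KV cache's `update()` now receives the `llama_context` itself and pulls the scheduler and graph helpers from it via getters, instead of being handed a pre-packaged `graph_params` struct. All types below (`mock_context`, `mock_sched`, `mock_kv_cache`) are stand-ins invented for illustration, not the real llama.cpp interfaces; the second argument to `graph_compute` mirrors the `false` passed in the diff but its meaning is assumed here.

```cpp
// Mock illustration only: names mirror the shape of the diff, not llama.cpp's API.
#include <cstdio>
#include <memory>

struct ggml_cgraph_stub { };  // opaque graph handle (stand-in)

struct mock_sched {
    void reset()                          { std::puts("sched: reset"); }
    void alloc_graph(ggml_cgraph_stub *)  { std::puts("sched: alloc graph"); }
};

// Stand-in for llama_context: exposes the getters/helpers the cache now calls.
struct mock_context {
    std::unique_ptr<mock_sched> sched = std::make_unique<mock_sched>();

    std::unique_ptr<mock_sched> & get_sched() { return sched; }
    ggml_cgraph_stub * graph_init()           { static ggml_cgraph_stub g; return &g; }
    void graph_compute(ggml_cgraph_stub *, bool batched) {
        std::printf("compute graph (batched=%d)\n", batched ? 1 : 0);
    }
};

// Stand-in for the KV cache: update() takes the context, not a params struct.
struct mock_kv_cache {
    bool has_shift = true;

    bool update(mock_context & lctx) {
        bool need_reserve = false;
        auto & sched = lctx.get_sched();
        if (has_shift) {
            sched->reset();
            ggml_cgraph_stub * gf = lctx.graph_init();
            // ... build the K-shift graph against gf ...
            sched->alloc_graph(gf);
            lctx.graph_compute(gf, /*batched=*/false);
            need_reserve = true;
            has_shift    = false;
        }
        return need_reserve;
    }
};

int main() {
    mock_context lctx;
    mock_kv_cache cache;
    if (cache.update(lctx)) {
        std::puts("caller: graph buffers need to be re-reserved");
    }
}
```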