 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model.h"
+#include "llama-context.h"

 #include <algorithm>
 #include <cassert>
@@ -367,10 +368,10 @@ void llama_kv_cache_unified::commit() {
     pending.ranges.clear();
 }

-bool llama_kv_cache_unified::update(const graph_params & params) {
+bool llama_kv_cache_unified::update(llama_context & lctx) {
     bool need_reserve = false;

-    const auto & sched = params.sched;
+    const auto & sched = lctx.get_sched();

     if (has_shift) {
         if (!get_can_shift()) {
@@ -381,17 +382,17 @@ bool llama_kv_cache_unified::update(const graph_params & params) {

         // apply K-shift if needed
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched);
+            ggml_backend_sched_reset(sched.get());

-            auto * gf = params.graph_init();
+            auto * gf = lctx.graph_init();

-            auto res = build_graph_shift(params, gf);
+            auto res = build_graph_shift(lctx, gf);

-            ggml_backend_sched_alloc_graph(sched, gf);
+            ggml_backend_sched_alloc_graph(sched.get(), gf);

             res->set_inputs(nullptr);

-            params.graph_compute(gf);
+            lctx.graph_compute(gf, false);

             need_reserve = true;
         }
@@ -408,18 +409,18 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
     if (do_defrag) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

-        if (defrag_prepare(params.n_max_nodes)) {
-            ggml_backend_sched_reset(sched);
+        if (defrag_prepare(lctx.graph_max_nodes())) {
+            ggml_backend_sched_reset(sched.get());

-            auto * gf = params.graph_init();
+            auto * gf = lctx.graph_init();

-            auto res = build_graph_defrag(params, gf);
+            auto res = build_graph_defrag(lctx, gf);

-            ggml_backend_sched_alloc_graph(sched, gf);
+            ggml_backend_sched_alloc_graph(sched.get(), gf);

             res->set_inputs(nullptr);

-            params.graph_compute(gf);
+            lctx.graph_compute(gf, false);

             need_reserve = true;
         }
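
The hunks above replace the old graph_params bag of fields with direct calls into llama_context. As a reading aid, here is a rough sketch of the accessor surface those calls assume; the member names come from the diff, but the return types are guesses inferred from the call sites (the .get() usage suggests smart-pointer members), not the actual declarations in llama-context.h, and unrelated members/includes are omitted.

// Sketch only: accessor surface assumed by the new update() path.
// Return types are inferred from usage above, not taken from llama-context.h.
struct llama_context {
    ggml_backend_sched_ptr              & get_sched();        // sched.get() is passed to ggml_backend_sched_*()
    const llama_cparams                 & get_cparams() const;
    const std::vector<ggml_backend_ptr> & get_backends() const;
    ggml_context_ptr                    & get_ctx_compute();  // .get() yields the raw ggml_context *

    ggml_cgraph * graph_init();                               // fresh graph for K-shift / defrag
    void          graph_compute(ggml_cgraph * gf, bool batched); // return type assumed; diff ignores it
    int32_t       graph_max_nodes() const;                    // replaces params.n_max_nodes
};
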
@@ -591,17 +592,17 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
591592}
592593
593594ggml_tensor * llama_kv_cache_unified::build_rope_shift (
594-         const  graph_params & params ,
595-                ggml_context * ctx,
596-                 ggml_tensor * cur,
597-                 ggml_tensor * shift,
598-                 ggml_tensor * factors,
599-                       float    freq_base,
600-                       float    freq_scale,
601-         ggml_backend_buffer * bbuf) const  {
602-     const  auto  & cparams  = params. cparams ;
603-     const  auto  & backends = params. backends ;
604-     const  auto  & sched    = params. sched ;
595+         llama_context & lctx ,
596+          ggml_context * ctx,
597+           ggml_tensor * cur,
598+           ggml_tensor * shift,
599+           ggml_tensor * factors,
600+                 float    freq_base,
601+                 float    freq_scale,
602+   ggml_backend_buffer * bbuf) const  {
603+     const  auto  & cparams  = lctx. get_cparams () ;
604+     const  auto  & backends = lctx. get_backends () ;
605+     const  auto  & sched    = lctx. get_sched () ;
605606
606607    const  auto  & n_ctx_orig = cparams.n_ctx_orig_yarn ;
607608
@@ -626,7 +627,7 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
626627            for  (const  auto  & backend : backends) {
627628                //  Figure out which backend KV cache belongs to
628629                if  (ggml_backend_supports_buft (backend.get (), ggml_backend_buffer_get_type (bbuf))) {
629-                     ggml_backend_sched_set_tensor_backend (sched, tmp, backend.get ());
630+                     ggml_backend_sched_set_tensor_backend (sched. get () , tmp, backend.get ());
630631                    break ;
631632                }
632633            }
@@ -674,13 +675,13 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
674675}
675676
676677llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift (
677-              const  graph_params & params ,
678-                     ggml_cgraph * gf) const  {
678+         llama_context & lctx ,
679+         ggml_cgraph * gf) const  {
679680    auto  res = std::make_unique<llm_graph_result>();
680681
681-     auto  * ctx = params .get_ctx_compute ();
682+     auto  * ctx = lctx .get_ctx_compute (). get ();
682683
683-     const  auto  & cparams = params. cparams ;
684+     const  auto  & cparams = lctx. get_cparams () ;
684685
685686    const  auto  & n_layer = hparams.n_layer ;
686687
@@ -716,7 +717,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
716717                ggml_row_size (k_l[il]->type , n_embd_k_gqa),
717718                0 );
718719
719-         ggml_tensor * cur = build_rope_shift (params , ctx, k, inp->k_shift , rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer );
720+         ggml_tensor * cur = build_rope_shift (lctx , ctx, k, inp->k_shift , rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer );
720721
721722        ggml_build_forward_expand (gf, cur);
722723    }
@@ -727,15 +728,15 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
727728}
728729
729730llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag (
730-              const  graph_params & params ,
731-                     ggml_cgraph * gf) const  {
731+         llama_context & lctx ,
732+           ggml_cgraph * gf) const  {
732733    auto  res = std::make_unique<llm_graph_result>();
733734
734-     auto  * ctx = params .get_ctx_compute ();
735+     auto  * ctx = lctx .get_ctx_compute (). get ();
735736
736737    const  auto  & ids = defrag_info.ids ;
737738
738-     const  auto  & cparams = params. cparams ;
739+     const  auto  & cparams = lctx. get_cparams () ;
739740
740741#if  0 
741742    // CPU defrag
@@ -1725,8 +1726,8 @@ void llama_kv_cache_recurrent::commit() {
17251726    pending.ranges .clear ();
17261727}
17271728
1728- bool  llama_kv_cache_recurrent::update (const  graph_params & params ) {
1729-     GGML_UNUSED (params );
1729+ bool  llama_kv_cache_recurrent::update (llama_context & lctx ) {
1730+     GGML_UNUSED (lctx );
17301731    return  false ;
17311732}
17321733
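For context, a hypothetical call site under the new signature could look like the sketch below; the surrounding function and member names (kv_self_update, kv_self) are assumptions for illustration and are not part of this diff.

// Hypothetical caller (sketch): the context now hands itself to the cache
// instead of first packing a graph_params struct.
void llama_context::kv_self_update() {
    // kv_self is assumed to be the context's llama_kv_cache instance
    const bool need_reserve = kv_self->update(*this);

    if (need_reserve) {
        // a K-shift or defrag graph was computed, so the worst-case compute
        // graph should be re-reserved by the context (details omitted)
    }
}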