//
// llama_context
//
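+// Per-context bookkeeping for MTP KV-cache reuse: the slot infos recorded from
+// the last main-model decode, and an optional forced override consumed by decode.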
+struct llama_context_kv_cache_data {
+    llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos;
+    llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force;
+    const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr;
+};

llama_context::llama_context(
        const llama_model & model,
@@ -106,6 +111,8 @@ llama_context::llama_context(
    cparams.op_offload = params.op_offload;
    cparams.kv_unified = params.kv_unified;

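+    // MTP bookkeeping, allocated here and freed in the destructor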
+    kv_cache_data = new llama_context_kv_cache_data();
+
    {
        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
@@ -371,6 +378,7 @@ llama_context::llama_context(

llama_context::~llama_context() {
    ggml_opt_free(opt_ctx);
+    delete static_cast<llama_context_kv_cache_data *>(kv_cache_data);
}

void llama_context::synchronize() {
@@ -1017,6 +1025,8 @@ int llama_context::encode(const llama_batch & batch_inp) {

int llama_context::decode(const llama_batch & batch_inp) {
    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
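+    // fetch the typed MTP bookkeeping set up in the constructor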
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
    LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n",
        batch_inp.update_mtp_kv ? "true" : "false",
        batch_inp.use_mtp_head ? "true" : "false"
@@ -1076,10 +1086,31 @@ int llama_context::decode(const llama_batch & batch_inp) {
    // handle any pending defrags/shifts
    kv_self_update(false);

-    llama_memory_context_ptr mctx;
+    std::unique_ptr<llama_memory_context_i> mctx;

    while (true) {
-        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
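+        // warmup keeps the stock path; otherwise a pending forced_sinfos bypasses
+        // the slot search and pins the ubatch to previously chosen KV cells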
+        if (cparams.warmup) {
+            mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        } else {
+            if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
+                LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
+
+                mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
+                    *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
+                );
+            } else {
+                mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+
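+                // a pass that neither drives the MTP head nor updates the MTP KV
+                // cache is a main-model pass: record its slot infos for later reuse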
+                if (!batch_inp.use_mtp_head && !batch_inp.update_mtp_kv) {
+                    if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
+                        kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
+                    } else {
+                        kvd->last_main_model_sinfos.clear();
+                    }
+                }
+            }
+        }
+
        if (!mctx) {
            return -2;
        }
@@ -1091,29 +1122,28 @@ int llama_context::decode(const llama_batch & batch_inp) {
            case LLAMA_MEMORY_STATUS_NO_UPDATE:
                {
                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
-
                    return -2;
                }
            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                {
+                    // if (use_last_main_model_sinfos) {
+                    //     LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__);
+                    //     return -1;
+                    // }
+
                    if (!did_optimize) {
                        did_optimize = true;
-
                        if (kv_self_update(true)) {
                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
-
                            continue;
                        }
                    }
-
                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
-
                    return 1;
                }
            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                {
                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
-
                    return -2;
                }
        }
@@ -3073,4 +3103,27 @@ void llama_opt_epoch(

void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
    ctx->draft_input_hidden_state = hidden_state;
+}
+
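+// Prepare a forced sinfo for an MTP KV update: reuse the slot info recorded from
+// the last main-model decode, truncated to the first n_accepted tokens. The
+// override stays armed until llama_mtp_cancel_sinfo_update() clears it.
+//
+// Expected call pattern (illustrative sketch; batch names are placeholders):
+//
+//     llama_decode(ctx, main_batch);                          // records the sinfos
+//     if (llama_mtp_prepare_sinfo_for_update(ctx, n_accepted)) {
+//         llama_decode(ctx, mtp_update_batch);                // reuses the same KV cells
+//         llama_mtp_cancel_sinfo_update(ctx);
+//     }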
+bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) {
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
+    const auto & last_sinfo = kvd->last_main_model_sinfos;
+
+    if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) {
+        LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.\n", __func__);
+        return false;
+    }
+
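+    // copy the recorded sinfos, then keep only the slots of the first n_accepted tokens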
+    kvd->resized_sinfo_for_force = last_sinfo;
+
+    kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted);
+
+    kvd->forced_sinfos = &kvd->resized_sinfo_for_force;
+
+    return true;
+}
+
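+// Drop a pending forced sinfo so the next decode resolves KV slots normally.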
+void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) {
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
+    kvd->forced_sinfos = nullptr;
}