@@ -750,7 +750,7 @@ static double calculate_vector_sum(const float* vec, size_t size) {
 }
 
 llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
-        bool do_mtp_kv_update, bool use_mtp_head, bool is_mtp_prompt_warmup) {
+        const llama_mtp_params & mtp_params) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
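The three ad-hoc booleans are folded into a single llama_mtp_params argument. Its definition is not part of this hunk; a minimal sketch consistent with the op-type values used later in the diff (MTP_OP_NONE, MTP_OP_WARMUP, MTP_OP_UPDATE_ACCEPTED) could look like the following, where MTP_OP_DRAFT_GEN is an assumed name for the draft-generation case:

// Sketch only: the real declaration lives elsewhere in this PR.
enum llama_mtp_op_type {
    MTP_OP_NONE = 0,        // regular decode, no MTP involvement
    MTP_OP_WARMUP,          // MTP warm-up pass over the prompt
    MTP_OP_UPDATE_ACCEPTED, // refresh the MTP KV cache for accepted tokens
    MTP_OP_DRAFT_GEN,       // assumed name: draft-token generation with the MTP head
};

struct llama_mtp_params {
    llama_mtp_op_type op_type = MTP_OP_NONE;
};

Bundling the mode into one aggregate also lets call sites that never touch MTP pass a plain { MTP_OP_NONE }, as the later hunks do.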
@@ -762,7 +762,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
 
     // the new graph parameters
     // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
-    const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update, use_mtp_head);
+    const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);
 
     if (!graph_reuse_disable && res->can_reuse(gparams)) {
         // LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
@@ -793,22 +793,22 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         }
     }
 
-    if (do_mtp_kv_update || (use_mtp_head && !do_mtp_kv_update)) { // If it is any MTP operation
+    if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
         const char * target_tensor_name = "result_embd_pooled";
         ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
 
         const float * source_hidden_state = nullptr;
-        if (is_mtp_prompt_warmup || (do_mtp_kv_update && !is_mtp_prompt_warmup)) {
+        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
             source_hidden_state = this->embd;
         } else {
             source_hidden_state = this->draft_input_hidden_state;
         }
 
         if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
             const size_t n_embd = this->model.hparams.n_embd;
-            const size_t n_tokens_for_sum = (do_mtp_kv_update && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
+            const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
             double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
-            const char * op_type = (do_mtp_kv_update) ? "MTP_UPDATE" : "DRAFT_GEN";
+            const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? "MTP_UPDATE" : "DRAFT_GEN";
 
             LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
 
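The checksum comes from calculate_vector_sum, declared static in this file (see the first hunk header); its body is outside this diff, but given how it is called here, a plain double-precision accumulation is the likely shape:

// Plausible shape of the debug helper; the actual implementation is not shown in this diff.
static double calculate_vector_sum(const float * vec, size_t size) {
    double sum = 0.0;
    for (size_t i = 0; i < size; ++i) {
        sum += vec[i]; // accumulate in double so the checksum stays stable across runs
    }
    return sum;
}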
@@ -833,20 +833,20 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     const int64_t t_exec_start_us = ggml_time_us();
     const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     const int64_t t_exec_end_us = ggml_time_us();
-    LLAMA_LOG_INFO(
-        "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
-        (t_exec_end_us - t_exec_start_us) / 1000.0,
-        ubatch.n_tokens,
-        do_mtp_kv_update ? "yes" : "no"
-    );
+    // LLAMA_LOG_INFO(
+    //     "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
+    //     (t_exec_end_us - t_exec_start_us) / 1000.0,
+    //     ubatch.n_tokens,
+    //     do_mtp_kv_update ? "yes" : "no"
+    // );
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
         return nullptr;
     }
 
     ret = GGML_STATUS_SUCCESS;
-    if (do_mtp_kv_update || use_mtp_head) {
+    if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
         ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum");
         if (sum_tensor) {
             LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n");
@@ -912,7 +912,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false, false);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, { MTP_OP_NONE });
 
     cparams.causal_attn = causal_attn_org;
 
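Call sites that do not involve MTP now pass an aggregate-initialized { MTP_OP_NONE }. On the decode path the op type travels on the batch itself (batch_inp.mtp_params below); a hypothetical speculative-decoding caller might drive the phases roughly like this, assuming ctx, tokens and n_tokens already exist:

// Illustrative only: field and value names follow what this diff references,
// the exact phase ordering is the caller's responsibility.
llama_batch batch = llama_batch_get_one(tokens, n_tokens);

batch.mtp_params = { MTP_OP_NONE };            // regular prompt/decode pass with the main head
llama_decode(ctx, batch);

batch.mtp_params = { MTP_OP_WARMUP };          // warm the MTP KV cache from the stored embeddings
llama_decode(ctx, batch);

batch.mtp_params = { MTP_OP_UPDATE_ACCEPTED }; // later: re-sync the MTP KV cache for accepted tokens
llama_decode(ctx, batch);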
@@ -1027,10 +1027,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
     auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
-    LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n",
-        batch_inp.update_mtp_kv ? "true" : "false",
-        batch_inp.use_mtp_head ? "true" : "false"
-    );
+    // LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n",
+    //     batch_inp.update_mtp_kv ? "true" : "false",
+    //     batch_inp.use_mtp_head ? "true" : "false"
+    // );
 
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
@@ -1101,7 +1101,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     } else {
         mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
 
-        if (!batch_inp.use_mtp_head && !batch_inp.update_mtp_kv) {
+        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
             if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
                 kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
             } else {
@@ -1158,9 +1158,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
     };
 
     int64_t n_outputs_prev = 0;
-    const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
-    const bool use_mtp_head     = batch_inp.use_mtp_head;
-    const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup;
+    // const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
+    // const bool use_mtp_head     = batch_inp.use_mtp_head;
+    // const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup;
 
     do {
         const auto & ubatch = mctx->get_ubatch();
@@ -1169,13 +1169,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
             for (uint32_t i = 0; i < std::min((uint32_t) 5, ubatch.n_tokens); ++i) {
                 pos_str += std::to_string(ubatch.pos[i]) + " ";
             }
-            LLAMA_LOG_WARN(
-                "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
-                ubatch.n_tokens,
-                batch_inp.update_mtp_kv ? "true" : "false",
-                batch_inp.use_mtp_head ? "true" : "false",
-                pos_str.c_str()
-            );
+            // LLAMA_LOG_WARN(
+            //     "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
+            //     ubatch.n_tokens,
+            //     batch_inp.update_mtp_kv ? "true" : "false",
+            //     batch_inp.use_mtp_head ? "true" : "false",
+            //     pos_str.c_str()
+            // );
         }
 
         // count the outputs in this ubatch
@@ -1193,16 +1193,16 @@ int llama_context::decode(const llama_batch & batch_inp) {
             // needs to happen before the graph is built
             n_outputs = n_outputs_new;
         }
-        if (do_mtp_kv_update) {
-            LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens);
-            std::string positions_str;
-            for (int i = 0; i < std::min((uint32_t) 5, ubatch.n_tokens); ++i) {
-                positions_str += std::to_string(ubatch.pos[i]) + " ";
-            }
-            LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str());
-        }
+        // if (do_mtp_kv_update) {
+        //     LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens);
+        //     std::string positions_str;
+        //     for (int i = 0; i < std::min((uint32_t) 5, ubatch.n_tokens); ++i) {
+        //         positions_str += std::to_string(ubatch.pos[i]) + " ";
+        //     }
+        //     LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str());
+        // }
         ggml_status status;
-        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head, is_prompt_warmup);
+        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params);
        if (!res) {
            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
            llama_pos pos_min[LLAMA_MAX_SEQ];
@@ -1261,17 +1261,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
            }
        }
 
-       if (use_mtp_head) {
-           if (t_embd != nullptr) {
-               LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! This will cause corruption.\n");
-           } else {
-               LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n");
-           }
-       }
+       // if (use_mtp_head) {
+       //     if (t_embd != nullptr) {
+       //         LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! This will cause corruption.\n");
+       //     } else {
+       //         LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n");
+       //     }
+       // }
 
        // extract embeddings
        if (t_embd && n_outputs > 0) {
-           if (!use_mtp_head) {
+           if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
                ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
                GGML_ASSERT(backend_embd != nullptr);
 
@@ -1389,7 +1389,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
        ggml_backend_sched_reset(sched.get());
    }
 
-   if (!use_mtp_head) {
+   if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
        synchronize();
        const size_t n_embd = this->model.hparams.n_embd;
        double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd);
@@ -1534,7 +1534,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 
    auto * res = gf_res_reserve.get();
 
-   const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false, false);
+   const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE });
 
    res->reset();
 
@@ -1556,8 +1556,7 @@ llm_graph_params llama_context::graph_params(
        const llama_ubatch & ubatch,
        const llama_memory_context_i * mctx,
        llm_graph_type gtype,
-       bool update_mtp_kv,
-       bool use_mtp_head) const {
+       const llama_mtp_params & mtp_params) const {
    return {
        /* .arch          =*/ model.arch,
        /* .hparams       =*/ model.hparams,
@@ -1570,8 +1569,7 @@ llm_graph_params llama_context::graph_params(
        /* .loras         =*/ &loras,
        /* .mctx          =*/ mctx,
        /* .cross         =*/ &cross,
-       /* .update_mtp_kv =*/ update_mtp_kv,
-       /* .use_mtp_head  =*/ use_mtp_head,
+       /* .mtp_params    =*/ mtp_params,
        /* .n_outputs     =*/ n_outputs,
        /* .cb            =*/ graph_get_cb(),
        /* .res           =*/ res,
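With mtp_params stored in llm_graph_params, it becomes part of the state that res->can_reuse(gparams) has to compare (see the comment in the second hunk: the graph topology must be uniquely determined by these parameters). Assuming the struct only carries op_type, the new piece of that comparison reduces to something like the helper below; the real can_reuse() logic is not part of this diff and llama_mtp_params_equal is a hypothetical name:

// Assumed shape of the reuse check for the new field.
static bool llama_mtp_params_equal(const llama_mtp_params & a, const llama_mtp_params & b) {
    return a.op_type == b.op_type;
}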
@@ -2312,7 +2310,7 @@ void llama_context::opt_epoch_iter(
 
        auto * res = gf_res_prev.get();
 
-       const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false, false);
+       const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE });
 
        res->reset();
 