@@ -1448,8 +1448,9 @@ llm_graph_params llama_context::graph_params(
 }
 
 llm_graph_params llama_context::mtp_graph_params(
-        llm_graph_result* res,
-        const llama_ubatch& ubatch) {
+        llm_graph_result * res,
+        const llama_ubatch& ubatch,
+        const llama_memory_context_i * mctx) {
     size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
     ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
     return {
@@ -1462,14 +1463,29 @@ llm_graph_params llama_context::mtp_graph_params(
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec        =*/ &cvec,
         /*.loras       =*/ &loras,
-        /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
+        /*.mctx        =*/ mctx,
         /*.cross       =*/ &cross,
         /*.n_outputs   =*/ 1,
         /*.cb          =*/ graph_get_cb(temp_sched),
         /*.res         =*/ res,
     };
 }
 
+std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+
+    const int64_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd  = hparams.n_embd;
+
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return nullptr;
+    }
+
+    return memory->init_batch(*balloc, 1, false);
+}
+
 ggml_status llama_context::graph_compute(
         ggml_cgraph * gf,
         bool batched) {
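
A note on how the pieces above fit together: the new `mctx` parameter of `mtp_graph_params()` is expected to be produced by `mtp_memory_batch()`, which splits a caller-supplied single-token batch and prepares the KV-cache memory context for the MTP pass. A condensed sketch of that wiring (it mirrors the updated `llama_build_and_execute_mtp_graph()` further down; `ctx` and `batch_inp` are assumed to come from the caller):

    // sketch only: the intended call pattern for the new mctx parameter
    auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
    llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);   // batch -> memory context (nullptr on failure)
    const auto & ubatch_mtp = mctx->get_ubatch();                       // single-token ubatch for the MTP step
    auto params_mtp = ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get());
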
@@ -2481,13 +2497,6 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
-ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
-    ctx->synchronize();
-
-    return ctx->get_embeddings_tensor();
-}
-
-
 // llama adapter API
 
 int32_t llama_set_adapter_lora(
@@ -2985,42 +2994,43 @@ void llama_opt_epoch(
             callback_eval);
 }
 
-llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) {
-    return ctx->mtp_graph_params(res, ubatch);
-}
-
-
-ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
-    return ctx->graph_compute(gf, batched);
-}
-
 void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
+        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
 
     const auto * model = llama_get_model(ctx);
 
     auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
+    llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
+    const auto & ubatch_mtp = mctx->get_ubatch();
 
-    llama_ubatch ubatch_mtp;
-    ubatch_mtp.n_tokens = 1;
-    ubatch_mtp.pos = &n_past;
 
-    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));
+    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
+    ggml_backend_sched_t sched = params_mtp->sched;
 
-    auto * gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);
+    auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
 
-    ggml_backend_sched_t sched = params_mtp->sched;
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+    }
+
+    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
 
     ggml_backend_sched_reset(sched);            // clear the allocation of the previous graph
     ggml_backend_sched_alloc_graph(sched, gf);  // explicitly allocate the new graph but do not execute it
 
     ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
-
     ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
+
+    ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
+    ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy the last accepted token's embedding into the graph input
+
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
     struct ggml_tensor * logits_mtp = res_mtp->get_logits();
-    LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void *)logits_mtp);
 
     if (logits_mtp) {
         ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
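
For reference, a hypothetical caller-side sketch of the reworked entry point (not part of this commit): `draft_one_token()` and its arguments are illustrative, and reading the draft logits back through `llama_get_logits_ith()` assumes `set_logits_ith()` stores them at `last_tok_idx`, as the code above suggests.

    #include "llama.h"

    // hypothetical usage: after the main decode has sampled token `id`,
    // run the MTP head once to produce draft logits for the next position
    static void draft_one_token(llama_context * ctx, llama_token id, int32_t n_past, int32_t last_tok_idx) {
        llama_batch batch_mtp = llama_batch_get_one(&id, 1); // single-token batch consumed by mtp_memory_batch()

        llama_build_and_execute_mtp_graph(ctx, batch_mtp, id, n_past, last_tok_idx);

        const float * draft_logits = llama_get_logits_ith(ctx, last_tok_idx); // written via set_logits_ith() above
        (void) draft_logits; // sample the draft token from these logits
    }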