@@ -227,8 +227,14 @@ llama_context::llama_context(
 
     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-    // buffer used to store the computation graph and the tensor meta data
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    // buffers used to store the computation graph and the tensor meta data
+    for (auto & res : gf_res) {
+        res.reset(new llm_graph_result());
+        res->reserve(max_nodes);
+    };
+
+    gf_res_reserve.reset(new llm_graph_result());
+    gf_res_reserve->reserve(max_nodes);
 
     // TODO: move these checks to ggml_backend_sched
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
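The constructor hunk above replaces the single buf_compute_meta buffer with a set of pre-reserved graph-result objects. Below is a minimal standalone sketch of that pattern (not part of the patch); the two-slot count, the stub methods, and the cur/prv/next accessors are assumptions inferred from how the diff uses them, not the actual llm_graph_result API.

#include <array>
#include <cstddef>
#include <memory>

// Hypothetical stand-in for llm_graph_result; only what this sketch needs.
struct graph_result_stub {
    size_t max_nodes = 0;
    void reserve(size_t n) { max_nodes = n; }   // the real class also sizes its metadata buffers here
    void init() { /* rebuild the graph skeleton in place, reusing the reserved storage */ }
};

// Double-buffered graph results plus a dedicated slot for worst-case "reserve" graphs.
struct context_sketch {
    std::array<std::unique_ptr<graph_result_stub>, 2> gf_res;
    std::unique_ptr<graph_result_stub> gf_res_reserve;
    size_t i_cur = 0;

    explicit context_sketch(size_t max_nodes) {
        for (auto & res : gf_res) {
            res.reset(new graph_result_stub());
            res->reserve(max_nodes);
        }
        gf_res_reserve.reset(new graph_result_stub());
        gf_res_reserve->reserve(max_nodes);
    }

    // advance to the next slot and return it - the slot we just left becomes "previous"
    graph_result_stub * gf_res_next() { i_cur = (i_cur + 1) % gf_res.size(); return gf_res[i_cur].get(); }
    graph_result_stub * gf_res_cur () { return gf_res[i_cur].get(); }
    graph_result_stub * gf_res_prv () { return gf_res[(i_cur + gf_res.size() - 1) % gf_res.size()].get(); }
};

In this sketch, reserving every slot up front means the per-ubatch path only re-initializes existing storage instead of allocating fresh buffers.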
@@ -388,10 +394,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }
 
-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -678,36 +680,40 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto * gf = graph_init();
+    gf_res_next()->init();
+
+    auto * gf = gf_res_cur()->get_gf();
     if (!gf) {
         LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    const bool can_reuse = graph_build(gf_res_cur(), gf_res_prv(), ubatch, gtype, mctx);
+    if (can_reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+        gf_res_next()->update(mctx);
+    } else {
+        //LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-    //LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
     }
 
-    res->set_inputs(&ubatch);
+    gf_res_cur()->set_inputs(&ubatch);
 
     const auto status = graph_compute(gf, ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
@@ -718,7 +724,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 
     ret = GGML_STATUS_SUCCESS;
 
-    return res;
+    return gf_res_cur();
 }
 
 int llama_context::encode(const llama_batch & batch_inp) {
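The reworked process_ubatch above asks graph_build whether the graph for this ubatch can reuse the previous one (presumably because the two are structurally identical); only when it cannot is the scheduler reset and the graph reallocated. The following self-contained toy (plain C++, no ggml, not part of the patch) walks through that reuse-or-rebuild decision and the flipping of the two result slots; every name in it is illustrative.

#include <cstdio>
#include <string>
#include <vector>

// Toy stand-in for a graph: a topology plus per-batch input data.
struct graph_desc {
    std::vector<std::string> nodes;   // stand-in for the graph topology
    std::vector<int>         inputs;  // stand-in for input data, which may change every ubatch
};

// "Build" the current graph and report whether it matches the previous one structurally.
static bool build_graph(graph_desc & cur, const graph_desc * prv, const std::vector<int> & batch) {
    cur.nodes  = {"embed", "attn", "ffn", "output"}; // same topology for every batch in this toy
    cur.inputs = batch;
    return prv != nullptr && prv->nodes == cur.nodes; // reuse only if the topology is identical
}

int main() {
    graph_desc slots[2];
    int cur = 0;

    for (int step = 0; step < 3; ++step) {
        const graph_desc * prv = (step == 0) ? nullptr : &slots[cur ^ 1];
        const bool can_reuse = build_graph(slots[cur], prv, {step, step + 1});
        if (can_reuse) {
            std::printf("step %d: reusing previous graph allocation\n", step);
        } else {
            std::printf("step %d: scheduler reset + fresh graph allocation\n", step);
        }
        cur ^= 1; // flip slots so the graph we just ran becomes "previous" next time
    }
    return 0;
}

Running it, the first step rebuilds while the later steps reuse, which mirrors the fast path the diff adds for repeated calls that produce the same graph.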
@@ -767,6 +773,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_outputs = n_tokens;
 
+    // TODO: when resetting the scheduler, clear prev graph buffers
+    gf_res_next()->init();
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -778,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -846,7 +854,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    //ggml_backend_sched_reset(sched.get());
 
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -1005,11 +1013,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
         n_outputs = n_outputs_new;
     }
 
-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
     if (!res) {
         // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1192,7 +1197,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    //ggml_backend_sched_reset(sched.get());
 
     return 0;
 }
@@ -1279,18 +1284,6 @@ int32_t llama_context::graph_max_nodes() const {
     return std::max<int32_t>(65536, 5*model.n_tensors());
 }
 
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
-}
-
 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
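The deleted graph_init() above created a no_alloc ggml context on top of the shared buf_compute_meta buffer plus a fresh custom graph on every call. With this change that responsibility presumably moves into each llm_graph_result, whose reserve()/init() are called in the constructor and per-build hunks. The sketch below (not part of the patch) shows one plausible shape for that, reusing only the ggml calls that appear in the removed code; the field and method names are guesses, not the real class.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "ggml.h"

// Illustrative only - not the actual llm_graph_result definition.
struct graph_result_sketch {
    std::vector<uint8_t> buf_meta;      // tensor metadata + graph bookkeeping live here
    ggml_context *       ctx = nullptr; // no_alloc context backed by buf_meta
    ggml_cgraph  *       gf  = nullptr;
    size_t               max_nodes = 0;

    // called once, at context construction time
    void reserve(size_t n_nodes) {
        max_nodes = n_nodes;
        buf_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
    }

    // called before each build - the same steps as the removed llama_context::graph_init()
    void init() {
        if (ctx) {
            ggml_free(ctx);
        }
        ggml_init_params params = {
            /*.mem_size   =*/ buf_meta.size(),
            /*.mem_buffer =*/ buf_meta.data(),
            /*.no_alloc   =*/ true,   // metadata only; tensor data is allocated later by the backend scheduler
        };
        ctx = ggml_init(params);
        gf  = ggml_new_graph_custom(ctx, max_nodes, false);
    }

    ggml_cgraph  * get_gf () const { return gf;  }
    ggml_context * get_ctx() const { return ctx; }

    ~graph_result_sketch() {
        if (ctx) {
            ggml_free(ctx);
        }
    }
};

In this sketch, keeping the metadata buffer inside each result is what would let two graphs coexist: the previous graph stays intact for the reuse comparison while the next one is built into its own buffer.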
@@ -1301,6 +1294,10 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
        LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
 
+    // TODO: when resetting the scheduler, clear prev graph buffers
+    gf_res_next()->init();
+    ggml_backend_sched_reset(sched.get());
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1307,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    gf_res_reserve->init();
+    auto * gf = gf_res_reserve->get_gf();
 
-    this->n_outputs = save_n_outputs;
-
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
+    const bool can_reuse = graph_build(gf_res_reserve.get(), nullptr, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    GGML_ASSERT(!can_reuse); // cannot reuse reserve graphs
 
-    ggml_backend_sched_reset(sched.get());
+    this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,15 +1324,17 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     return gf;
 }
 
-llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-        ggml_cgraph * gf,
+bool llama_context::graph_build(
+        llm_graph_result_i * gf_res_cur,
+        llm_graph_result_i * gf_res_prv,
         const llama_ubatch & ubatch,
         llm_graph_type gtype,
        const llama_memory_context_i * mctx) {
     return model.build_graph(
         {
-            /*.ctx        =*/ ctx,
+            /*.ctx        =*/ gf_res_cur->get_ctx(),
+            /*.gf_res_cur =*/ static_cast<llm_graph_result *>(gf_res_cur),
+            /*.gf_res_prv =*/ static_cast<llm_graph_result *>(gf_res_prv),
             /*.arch       =*/ model.arch,
             /*.hparams    =*/ model.hparams,
             /*.cparams    =*/ cparams,
@@ -1352,7 +1347,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.cross      =*/ &cross,
             /*.n_outputs  =*/ n_outputs,
             /*.cb         =*/ graph_get_cb(),
-        }, gf, gtype);
+        }, gtype);
 }
 
 ggml_status llama_context::graph_compute(
@@ -2064,8 +2059,11 @@ void llama_context::opt_epoch_iter(
             break;
         }
 
-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+        gf_res_cur()->init();
+        auto * gf = gf_res_cur()->get_gf();
+
+        const bool can_reuse = graph_build(gf_res_cur(), nullptr, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+        GGML_ASSERT(!can_reuse); // cannot reuse optimization graphs
 
         struct ggml_context * ctx_compute_opt;
         {
@@ -2078,10 +2076,10 @@ void llama_context::opt_epoch_iter(
             };
             ctx_compute_opt = ggml_init(params);
         }
-        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, gf_res_cur()->get_tokens(), gf_res_cur()->get_logits());
         ggml_opt_alloc(opt_ctx, train);
 
-        res->set_inputs(&ubatch);
+        gf_res_cur()->set_inputs(&ubatch);
         {
            struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
            GGML_ASSERT(labels->ne[1] == n_ubatch);