@@ -101,7 +101,8 @@ llama_context::llama_context(
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-    cparams.op_offload = params.op_offload;
+    cparams.op_offload  = params.op_offload;
+    cparams.graph_reuse = params.graph_reuse;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
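Note: the new graph_reuse flag is copied from llama_context_params into cparams here and defaults to false (see the llama_context_default_params() hunk further down). A minimal sketch of how a caller could opt in, assuming the public struct in llama.h gains the matching field (that header change is not part of this file) and using llama_init_from_model() from the public API; the helper name is illustrative:

    // sketch only, not part of this diff
    #include "llama.h"

    static llama_context * make_ctx_with_graph_reuse(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.graph_reuse = true; // assumed public field mirroring the cparams member set above
        return llama_init_from_model(model, cparams);
    }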
@@ -227,8 +228,8 @@ llama_context::llama_context(
 
     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-    // buffer used to store the computation graph and the tensor meta data
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    gf_res_prev.reset(new llm_graph_result(max_nodes));
+    gf_res_reserve.reset(new llm_graph_result(max_nodes));
 
     // TODO: move these checks to ggml_backend_sched
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -388,10 +389,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }
 
-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -678,38 +675,50 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto * gf = graph_init();
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    auto * res = gf_res_prev.get();
+    auto * gf  = res->get_gf();
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    const auto gparams = graph_params(res, ubatch, mctx, gtype);
 
-    //LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+    const bool can_reuse = cparams.graph_reuse && res->update(gparams);
+    if (can_reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+        n_reused++;
+    } else {
+        res->reset();
 
-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        //const auto t_start_us = ggml_time_us();
+
+        gf = model.build_graph(gparams);
+
+        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+            ret = GGML_STATUS_FAILED;
+            return nullptr;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
     }
 
     res->set_inputs(&ubatch);
 
-    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
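Note: stripped of error handling, the new process_ubatch() body above amounts to the following decision; this is a condensed paraphrase for readability, not a drop-in replacement, and it assumes llm_graph_result_i exposes update(), reset() and get_gf() exactly as used in the hunk:

    // condensed paraphrase of the hunk above (sketch only)
    auto * res = gf_res_prev.get();                      // result object kept from the previous ubatch
    auto * gf  = res->get_gf();
    const auto gparams = graph_params(res, ubatch, mctx, gtype);

    if (cparams.graph_reuse && res->update(gparams)) {   // same topology -> skip rebuild and realloc
        n_reused++;
    } else {
        res->reset();                                    // drop the stale graph
        ggml_backend_sched_reset(sched.get());
        gf = model.build_graph(gparams);                 // build a fresh graph ...
        ggml_backend_sched_alloc_graph(sched.get(), gf); // ... and allocate it on the scheduler
    }

    res->set_inputs(&ubatch);                            // inputs are refreshed in both paths
    graph_compute(res->get_gf(), ubatch.n_tokens > 1);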
@@ -767,6 +776,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_outputs = n_tokens;
 
+    gf_res_prev->reset();
+
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -778,7 +789,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -846,7 +857,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    //ggml_backend_sched_reset(sched.get());
 
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -1005,11 +1016,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
             n_outputs = n_outputs_new;
         }
 
-        ggml_backend_sched_reset(sched.get());
-        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1192,7 +1200,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    //ggml_backend_sched_reset(sched.get());
 
     return 0;
 }
@@ -1275,20 +1283,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 // graph
 //
 
-int32_t llama_context::graph_max_nodes() const {
-    return std::max<int32_t>(65536, 5*model.n_tensors());
-}
-
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+uint32_t llama_context::graph_max_nodes() const {
+    return std::max<uint32_t>(65536u, 5u*model.n_tensors());
 }
 
 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1297,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
 
+    gf_res_prev->reset();
+    ggml_backend_sched_reset(sched.get());
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1309,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    auto * res = gf_res_reserve.get();
 
-    this->n_outputs = save_n_outputs;
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
 
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
+    res->reset();
 
-    ggml_backend_sched_reset(sched.get());
+    auto * gf = model.build_graph(gparams);
+
+    this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1328,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     return gf;
 }
 
-llm_graph_result_ptr llama_context::graph_build(
-            ggml_context * ctx,
-             ggml_cgraph * gf,
-      const llama_ubatch & ubatch,
-          llm_graph_type   gtype,
-    const llama_memory_context_i * mctx) {
-    return model.build_graph(
-        {
-            /*.ctx         =*/ ctx,
-            /*.arch        =*/ model.arch,
-            /*.hparams     =*/ model.hparams,
-            /*.cparams     =*/ cparams,
-            /*.ubatch      =*/ ubatch,
-            /*.sched       =*/ sched.get(),
-            /*.backend_cpu =*/ backend_cpu,
-            /*.cvec        =*/ &cvec,
-            /*.loras       =*/ &loras,
-            /*.mctx        =*/ mctx,
-            /*.cross       =*/ &cross,
-            /*.n_outputs   =*/ n_outputs,
-            /*.cb          =*/ graph_get_cb(),
-        }, gf, gtype);
+llm_graph_params llama_context::graph_params(
-
+        llm_graph_result_i * res,
+        const llama_ubatch & ubatch,
+        const llama_memory_context_i * mctx,
+        llm_graph_type gtype) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ gtype,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ mctx,
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ n_outputs,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
 }
 
 ggml_status llama_context::graph_compute(
@@ -1930,6 +1926,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
     data.t_eval_ms = 1e-3 * t_eval_us;
     data.n_p_eval = std::max(1, n_p_eval);
     data.n_eval = std::max(1, n_eval);
+    data.n_reused = std::max(0, n_reused);
 
     return data;
 }
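Note: n_reused is exported through the existing perf API. A minimal sketch of reading it from application code, assuming the public llama_perf_context_data struct in llama.h gains the n_reused field (that header change is not shown in this file):

    // sketch only, not part of this diff
    #include <cstdio>
    #include "llama.h"

    static void print_graph_reuse(const llama_context * ctx) {
        const llama_perf_context_data data = llama_perf_context(ctx);
        printf("graphs reused: %d\n", data.n_reused); // assumed field added alongside this diff
    }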
@@ -1938,6 +1935,7 @@ void llama_context::perf_reset() {
     t_start_us = ggml_time_us();
     t_eval_us = n_eval = 0;
     t_p_eval_us = n_p_eval = 0;
+    n_reused = 0;
 }
 
 //
@@ -2064,8 +2062,13 @@ void llama_context::opt_epoch_iter(
                 break;
             }
 
-            auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+            auto * res = gf_res_prev.get();
+
+            const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+
+            res->reset();
+
+            auto * gf = model.build_graph(gparams);
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2187,6 +2190,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf     =*/ true,
         /*.op_offload  =*/ true,
         /*.swa_full    =*/ true,
+        /*.graph_reuse =*/ false,
     };
 
     return result;
@@ -2807,6 +2811,7 @@ void llama_perf_context_print(const llama_context * ctx) {
     LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
     LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+    LLAMA_LOG_INFO("%s:    graphs reused = %10d\n", __func__, data.n_reused);
 }
 
 void llama_perf_context_reset(llama_context * ctx) {