@@ -255,7 +255,8 @@ void llama_context::init() {
255255 // reserve pp graph first so that buffers are only allocated once
256256 {
257257 llama_ubatch ubatch_pp = { true , n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr , nullptr , nullptr , nullptr , nullptr };
258- auto res_pp = graph_build (ubatch_pp, true );
258+ auto ctx = graph_init ();
259+ auto res_pp = graph_build (ctx, ubatch_pp, true );
259260 auto & gf_pp = res_pp.gf ;
260261 if (!ggml_backend_sched_reserve (sched.get (), gf_pp)) {
261262 LLAMA_LOG_ERROR (" %s: failed to allocate compute pp buffers\n " , __func__);
@@ -269,7 +270,8 @@ void llama_context::init() {
269270 // reserve with tg graph to get the number of splits and nodes
270271 {
271272 llama_ubatch ubatch_tg = { true , 1 , 1 , n_seqs, &token, nullptr , nullptr , nullptr , nullptr , nullptr };
272- auto res_tg = graph_build (ubatch_tg, true );
273+ auto ctx = graph_init ();
274+ auto res_tg = graph_build (ctx, ubatch_tg, true );
273275 auto & gf_tg = res_tg.gf ;
274276 if (!ggml_backend_sched_reserve (sched.get (), gf_tg)) {
275277 LLAMA_LOG_ERROR (" %s: failed to allocate compute tg buffers\n " , __func__);
@@ -282,7 +284,8 @@ void llama_context::init() {
282284 // reserve again with pp graph to avoid ggml-alloc reallocations during inference
283285 {
284286 llama_ubatch ubatch_pp = { true , n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr , nullptr , nullptr , nullptr , nullptr };
285- auto res_pp = graph_build (ubatch_pp, true );
287+ auto ctx = graph_init ();
288+ auto res_pp = graph_build (ctx, ubatch_pp, true );
286289 auto & gf_pp = res_pp.gf ;
287290 if (!ggml_backend_sched_reserve (sched.get (), gf_pp)) {
288291 LLAMA_LOG_ERROR (" %s: failed to allocate compute pp buffers\n " , __func__);
@@ -569,6 +572,13 @@ ggml_context_ptr llama_context::graph_init() {
569572 return ggml_context_ptr { ggml_init (params) };
570573}
571574
// Build the compute graph for `ubatch` by forwarding to model.build_graph().
//
// ctx        - ggml context supplied by the caller (the call sites in this change
//              obtain it from graph_init() immediately before calling); taken by
//              non-const reference and passed straight through
// ubatch     - micro-batch description, passed through unchanged
// worst_case - passed through unchanged; the reservation call sites pass true,
//              the regular decode/encode call sites pass false
//
// Returns whatever model.build_graph() returns; callers read the graph from its
// `gf` member.
//
// NOTE(review): presumably `ctx` has to stay alive for as long as the returned
// result is used, since the caller now owns it instead of this function creating
// it internally - confirm against model.build_graph().
575+ llama_graph_result llama_context::graph_build (
576+ ggml_context_ptr & ctx,
577+ const llama_ubatch & ubatch,
578+ bool worst_case) {
579+ return model.build_graph (ctx, *this , cparams, ubatch, worst_case);
580+ }
581+
572582enum ggml_status llama_context::graph_compute (
573583 ggml_cgraph * graph,
574584 bool batched) {
@@ -907,10 +917,6 @@ void llama_context::build_cb(
907917 }
908918}
909919
910- llama_graph_result llama_context::graph_build (const llama_ubatch & ubatch, bool worst_case) {
911- return model.build_graph (*this , cparams, ubatch, graph_init (), worst_case);
912- }
913-
914920llama_perf_context_data llama_context::perf_get_data () const {
915921 llama_perf_context_data data = {};
916922
@@ -1831,7 +1837,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
18311837 llama_token token = model.vocab .token_bos (); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
18321838 llama_ubatch ubatch = { true , n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr , nullptr , nullptr , nullptr , nullptr };
18331839
1834- auto res = graph_build (ubatch, true );
1840+ auto ctx = graph_init ();
1841+ auto res = graph_build (ctx, ubatch, true );
18351842
18361843 // initialize scheduler with the worst-case graph
18371844 ggml_backend_sched_reset (sched.get ());
@@ -1845,7 +1852,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
18451852 ggml_backend_sched_reset (sched.get ());
18461853 ggml_backend_sched_set_eval_callback (sched.get (), cparams.cb_eval , cparams.cb_eval_user_data );
18471854
1848- auto res = graph_build (ubatch, false );
1855+ auto ctx = graph_init ();
1856+ auto res = graph_build (ctx, ubatch, false );
18491857
18501858 auto & gf = res.gf ;
18511859
@@ -2092,7 +2100,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
20922100 ggml_backend_sched_reset (sched.get ());
20932101 ggml_backend_sched_set_eval_callback (sched.get (), cparams.cb_eval , cparams.cb_eval_user_data );
20942102
2095- auto res = graph_build (ubatch, false );
2103+ auto ctx = graph_init ();
2104+ auto res = graph_build (ctx, ubatch, false );
20962105
20972106 auto & gf = res.gf ;
20982107
0 commit comments