@@ -255,7 +255,8 @@ void llama_context::init() {
255
255
// reserve pp graph first so that buffers are only allocated once
256
256
{
257
257
llama_ubatch ubatch_pp = { true , n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr , nullptr , nullptr , nullptr , nullptr };
258
- auto res_pp = graph_build (ubatch_pp, true );
258
+ auto ctx = graph_init ();
259
+ auto res_pp = graph_build (ctx, ubatch_pp, true );
259
260
auto & gf_pp = res_pp.gf ;
260
261
if (!ggml_backend_sched_reserve (sched.get (), gf_pp)) {
261
262
LLAMA_LOG_ERROR (" %s: failed to allocate compute pp buffers\n " , __func__);
@@ -269,7 +270,8 @@ void llama_context::init() {
269
270
// reserve with tg graph to get the number of splits and nodes
270
271
{
271
272
llama_ubatch ubatch_tg = { true , 1 , 1 , n_seqs, &token, nullptr , nullptr , nullptr , nullptr , nullptr };
272
- auto res_tg = graph_build (ubatch_tg, true );
273
+ auto ctx = graph_init ();
274
+ auto res_tg = graph_build (ctx, ubatch_tg, true );
273
275
auto & gf_tg = res_tg.gf ;
274
276
if (!ggml_backend_sched_reserve (sched.get (), gf_tg)) {
275
277
LLAMA_LOG_ERROR (" %s: failed to allocate compute tg buffers\n " , __func__);
@@ -282,7 +284,8 @@ void llama_context::init() {
282
284
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
283
285
{
284
286
llama_ubatch ubatch_pp = { true , n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr , nullptr , nullptr , nullptr , nullptr };
285
- auto res_pp = graph_build (ubatch_pp, true );
287
+ auto ctx = graph_init ();
288
+ auto res_pp = graph_build (ctx, ubatch_pp, true );
286
289
auto & gf_pp = res_pp.gf ;
287
290
if (!ggml_backend_sched_reserve (sched.get (), gf_pp)) {
288
291
LLAMA_LOG_ERROR (" %s: failed to allocate compute pp buffers\n " , __func__);
@@ -569,6 +572,13 @@ ggml_context_ptr llama_context::graph_init() {
569
572
return ggml_context_ptr { ggml_init (params) };
570
573
}
571
574
575
// Build the computation graph for `ubatch` inside the caller-provided ggml
// context `ctx` (every visible call site obtains `ctx` via graph_init()
// immediately before calling this).
//
// @param ctx        ggml context owning the graph's tensor metadata; taken by
//                   non-const reference so the caller retains ownership of the
//                   context for the lifetime of the returned graph.
// @param ubatch     micro-batch describing the inputs the graph is built for.
// @param worst_case when true, callers use the result to reserve scheduler /
//                   compute buffers for the worst-case graph (see the reserve
//                   paths in init()/decode()); false for a real-eval graph.
// @return the llama_graph_result produced by llama_model::build_graph
//         (callers read its `gf` cgraph member).
//
// Thin delegation wrapper: all graph construction logic lives in
// llama_model::build_graph.
llama_graph_result llama_context::graph_build(
        ggml_context_ptr & ctx,
        const llama_ubatch & ubatch,
        bool worst_case) {
    return model.build_graph(ctx, *this, cparams, ubatch, worst_case);
}
572
582
enum ggml_status llama_context::graph_compute (
573
583
ggml_cgraph * graph,
574
584
bool batched) {
@@ -907,10 +917,6 @@ void llama_context::build_cb(
907
917
}
908
918
}
909
919
910
- llama_graph_result llama_context::graph_build (const llama_ubatch & ubatch, bool worst_case) {
911
- return model.build_graph (*this , cparams, ubatch, graph_init (), worst_case);
912
- }
913
-
914
920
llama_perf_context_data llama_context::perf_get_data () const {
915
921
llama_perf_context_data data = {};
916
922
@@ -1831,7 +1837,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
1831
1837
llama_token token = model.vocab .token_bos (); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
1832
1838
llama_ubatch ubatch = { true , n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr , nullptr , nullptr , nullptr , nullptr };
1833
1839
1834
- auto res = graph_build (ubatch, true );
1840
+ auto ctx = graph_init ();
1841
+ auto res = graph_build (ctx, ubatch, true );
1835
1842
1836
1843
// initialize scheduler with the worst-case graph
1837
1844
ggml_backend_sched_reset (sched.get ());
@@ -1845,7 +1852,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
1845
1852
ggml_backend_sched_reset (sched.get ());
1846
1853
ggml_backend_sched_set_eval_callback (sched.get (), cparams.cb_eval , cparams.cb_eval_user_data );
1847
1854
1848
- auto res = graph_build (ubatch, false );
1855
+ auto ctx = graph_init ();
1856
+ auto res = graph_build (ctx, ubatch, false );
1849
1857
1850
1858
auto & gf = res.gf ;
1851
1859
@@ -2092,7 +2100,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
2092
2100
ggml_backend_sched_reset (sched.get ());
2093
2101
ggml_backend_sched_set_eval_callback (sched.get (), cparams.cb_eval , cparams.cb_eval_user_data );
2094
2102
2095
- auto res = graph_build (ubatch, false );
2103
+ auto ctx = graph_init ();
2104
+ auto res = graph_build (ctx, ubatch, false );
2096
2105
2097
2106
auto & gf = res.gf ;
2098
2107
0 commit comments