@@ -294,10 +294,7 @@ llama_context::llama_context(
         // TODO: something cleaner
         const auto n_outputs_save = n_outputs;
 
-        // max number of outputs
-        n_outputs = n_tokens;
-
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
         int n_splits_pp = -1;
         int n_nodes_pp  = -1;
@@ -313,8 +310,15 @@ llama_context::llama_context(
         // reserve pp graph first so that buffers are only allocated once
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+
+            // max number of outputs
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -326,20 +330,33 @@ llama_context::llama_context(
         // reserve with tg graph to get the number of splits and nodes
         {
             llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+
+            n_outputs = ubatch_tg.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
+
             n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_tg  = ggml_graph_n_nodes(gf);
         }
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
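
The net effect of the three hunks: instead of setting n_outputs = n_tokens once up front, each reservation block now derives n_outputs from the ubatch it is about to reserve, so the prompt-processing (pp) graphs are built with n_outputs == n_tokens while the text-generation (tg) graph is built with n_outputs == 1. Below is a minimal standalone sketch of that pattern; the ContextSketch type and reserve_graph helper are hypothetical names for illustration only, not llama.cpp API (the real code goes through graph_init(), graph_build() and ggml_backend_sched_reserve()).

    // Minimal sketch of per-ubatch n_outputs during graph reservation.
    // ContextSketch and reserve_graph are illustrative, not llama.cpp API.
    #include <cstdint>
    #include <cstdio>

    struct Ubatch {
        uint32_t n_tokens;
        uint32_t n_seqs;
    };

    struct ContextSketch {
        uint32_t n_outputs = 0;

        // Reserve a worst-case graph for one ubatch. n_outputs is derived
        // from the ubatch itself, so a pp reservation builds with n_outputs
        // equal to its token count while a tg reservation builds with 1.
        void reserve_graph(const Ubatch & ubatch) {
            n_outputs = ubatch.n_tokens; // max number of outputs for this ubatch

            std::printf("reserving graph for n_tokens = %u, n_seqs = %u\n",
                        ubatch.n_tokens, ubatch.n_seqs);

            // ... build the graph and reserve the scheduler buffers here ...
        }
    };

    int main() {
        ContextSketch ctx;

        const uint32_t n_tokens = 512; // worst-case prompt-processing batch size
        const uint32_t n_seqs   = 4;

        ctx.reserve_graph({ n_tokens, n_seqs }); // pp graph: buffers allocated once
        ctx.reserve_graph({ 1,        n_seqs }); // tg graph: single-token decode
        ctx.reserve_graph({ n_tokens, n_seqs }); // pp again: pin worst-case sizes

        return 0;
    }

Reserving pp first, then tg, then pp again mirrors the comments in the diff itself: the first pp pass allocates the buffers once, and the final pp pass leaves the scheduler sized for the worst case so ggml-alloc does not reallocate during inference.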