@@ -294,10 +294,7 @@ llama_context::llama_context(
         // TODO: something cleaner
         const auto n_outputs_save = n_outputs;

-        // max number of outputs
-        n_outputs = n_tokens;
-
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);

         int n_splits_pp = -1;
         int n_nodes_pp  = -1;
@@ -313,8 +310,15 @@ llama_context::llama_context(
         // reserve pp graph first so that buffers are only allocated once
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+
+            // max number of outputs
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -326,20 +330,33 @@ llama_context::llama_context(
         // reserve with tg graph to get the number of splits and nodes
         {
             llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+
+            n_outputs = ubatch_tg.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
+
             n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_tg  = ggml_graph_n_nodes(gf);
         }

         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
+
+            n_outputs = ubatch_pp.n_tokens;
+
+            LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
+
             auto * gf = graph_init();
             graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+
             if (!ggml_backend_sched_reserve(sched.get(), gf)) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
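
For context, a minimal, self-contained sketch of the reservation pattern the diff moves to: n_outputs is saved, then set from each trial ubatch's n_tokens immediately before that graph is reserved (pp first so buffers are allocated once, tg to count splits and nodes, then pp again so ggml-alloc does not reallocate during inference), and finally restored. The ubatch_t and ctx_t types and the reserve() helper below are hypothetical stand-ins for illustration, not llama.cpp's actual API.

// Sketch: per-ubatch worst-case output sizing during graph reservation.
// ubatch_t and ctx_t are invented stand-ins for llama_ubatch and the
// llama_context state touched by this commit.
#include <cstdint>
#include <cstdio>

struct ubatch_t {
    uint32_t n_tokens;
    uint32_t n_seqs;
};

struct ctx_t {
    int32_t n_outputs = 0;

    // Stands in for graph_init() + graph_build() + ggml_backend_sched_reserve():
    // size the backend buffers for this ubatch's worst case.
    void reserve(const ubatch_t & ub) {
        n_outputs = (int32_t) ub.n_tokens; // worst case: every token produces an output
        std::printf("reserving graph for n_tokens = %u, n_seqs = %u, n_outputs = %d\n",
                ub.n_tokens, ub.n_seqs, n_outputs);
    }
};

int main() {
    ctx_t ctx;
    const uint32_t n_tokens = 512;
    const uint32_t n_seqs   = 4;

    const int32_t n_outputs_save = ctx.n_outputs; // save, as the constructor does

    ctx.reserve({ n_tokens, n_seqs }); // pp graph: allocate buffers once
    ctx.reserve({ 1,        n_seqs }); // tg graph: count splits and nodes
    ctx.reserve({ n_tokens, n_seqs }); // pp again: avoid ggml-alloc reallocations

    ctx.n_outputs = n_outputs_save;    // restore after reservation

    return 0;
}

The practical effect of assigning n_outputs per ubatch rather than once up front is visible in the diff itself: the tg graph is now reserved with a 1-token output count instead of inheriting the pp worst case of n_tokens.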