diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ac8453ab741d4..ffc6e5e8974b1 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -296,14 +296,13 @@ llama_context::llama_context(
 
         // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            const bool resolve_fa_auto = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), /*alloc =*/ resolve_fa_auto);
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
-            if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-                ggml_backend_sched_alloc_graph(sched.get(), gf);
-
+            if (resolve_fa_auto) {
                 const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
                 bool fa_device_mismatch = false;
                 for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool alloc) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1401,9 +1400,16 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        return nullptr;
+    if (alloc) {
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            return nullptr;
+        }
+    } else {
+        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            return nullptr;
+        }
     }
 
     return gf;
diff --git a/src/llama-context.h b/src/llama-context.h
index a372bcfbe41aa..89ab9d15ce51c 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -196,7 +196,8 @@ struct llama_context {
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    // optionally allocate the graph so that tensor backend assignments can be retrieved
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool alloc = false);
 
 private:
     llm_graph_params graph_params(
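
Not part of the patch, for context only: with alloc = true the scheduler actually allocates the reserve graph, which is what lets the constructor's flash-attention auto-resolution inspect the backend assigned to each node. A minimal sketch of that query pattern, assuming a valid ggml_backend_sched_t sched and a graph gf returned by graph_reserve(..., /*alloc =*/ true):

    // Sketch: walk the allocated reserve graph and log the backend chosen for each node.
    // Uses only the public ggml-backend scheduler API; sched and gf are assumed to exist.
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        ggml_tensor * node = ggml_graph_node(gf, i);
        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, node);
        if (backend != nullptr) {
            LLAMA_LOG_DEBUG("%s: node %s -> backend %s\n", __func__, node->name, ggml_backend_name(backend));
        }
    }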