src/llama-context.cpp (22 changes: 14 additions & 8 deletions)
@@ -296,14 +296,13 @@ llama_context::llama_context(
 
     // reserve pp (prompt processing) graph first so that buffers are only allocated once
     {
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+        const bool resolve_fa_auto = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), /*alloc =*/ resolve_fa_auto);
         if (!gf) {
             throw std::runtime_error("failed to allocate compute pp buffers");
         }
 
-        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
+        if (resolve_fa_auto) {
             const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
             bool fa_device_mismatch = false;
             for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
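Note on the hunk above: reserving alone never binds nodes to backends, so resolving LLAMA_FLASH_ATTN_TYPE_AUTO needs the graph actually allocated before the assignment check that follows. A minimal sketch of that kind of check, assuming the allocated graph gf and scheduler sched from the surrounding code (the PR's exact name matching and mismatch handling are elided):

    // after ggml_backend_sched_alloc_graph(), the scheduler can report the
    // backend each node was assigned to
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        ggml_tensor * node = ggml_graph_node(gf, i);
        // flash-attention nodes are identified by their name prefix
        if (strncmp(node->name, LLAMA_TENSOR_NAME_FATTN, strlen(LLAMA_TENSOR_NAME_FATTN)) == 0) {
            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), node);
            // compare this backend's device with the layer's device to decide
            // whether flash attention can stay enabled automatically
            (void) backend;
        }
    }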
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool alloc) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1401,9 +1400,16 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        return nullptr;
+    if (alloc) {
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            return nullptr;
+        }
+    } else {
+        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            return nullptr;
+        }
     }
 
     return gf;
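The two branches map to different ggml-backend scheduler entry points: ggml_backend_sched_reserve() only measures the graph and sizes the compute buffers for the worst case, while ggml_backend_sched_alloc_graph() commits a real allocation, which is what makes per-node backend queries meaningful. A standalone sketch of the two alternatives (one or the other per pass; sched and gf assumed already set up):

    // alternative A: measure only; buffers are sized for the worst case,
    // but nodes are not committed to backends
    if (!ggml_backend_sched_reserve(sched, gf)) {
        // handle failure
    }

    // alternative B: full allocation; every node is assigned to a backend,
    // so the assignment can be read back afterwards
    if (!ggml_backend_sched_alloc_graph(sched, gf)) {
        // handle failure
    }
    ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, ggml_graph_node(gf, 0));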
src/llama-context.h (3 changes: 2 additions & 1 deletion)
@@ -196,7 +196,8 @@ struct llama_context {
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    // optionally allocate the graph so that tensor backend assignments can be retrieved
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool alloc = false);
 
 private:
     llm_graph_params graph_params(
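Because alloc is defaulted to false, existing graph_reserve() call sites keep their reserve-only behaviour unchanged; only the flash-attention auto-resolution path opts in. A hypothetical pair of calls showing both shapes (argument names assumed from the implementation above):

    // unchanged call sites: reserve only
    ggml_cgraph * gf_pp = graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get());
    // opt-in: reserve and allocate, so backend assignments can be inspected
    ggml_cgraph * gf_fa = graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get(), /*alloc =*/ true);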