
Commit 9777032

llama : separate compute buffer reserve from fattn check (ggml-org#15696)
Exposes ggml_backend_sched_split_graph() to allow splitting the graph without allocating compute buffers and uses it to split the graph for the automatic Flash Attention check.
1 parent 7d3c9f2 commit 9777032

4 files changed: +64 additions, -58 deletions


ggml/include/ggml-backend.h

Lines changed: 3 additions & 0 deletions
@@ -307,6 +307,9 @@ extern "C" {
     GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
     GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);

+    // Split graph without allocating it
+    GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+
     // Allocate and compute graph on the backend scheduler
     GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
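
As a quick orientation (not part of the commit), here is a minimal sketch of how a caller might combine the newly exported function with the existing scheduler API to see which backend each node would be assigned to, without reserving any compute buffers. It assumes `sched` and `graph` were created and built elsewhere; the helper name is hypothetical.

// Sketch: split the graph (no compute buffer allocation) and report the
// backend chosen for each node. Error handling is omitted for brevity.
#include <cstdio>

#include "ggml.h"
#include "ggml-backend.h"

static void print_node_backends(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
    // assigns backends to ops and splits the graph, but does not allocate compute buffers
    ggml_backend_sched_split_graph(sched, graph);

    for (int i = 0; i < ggml_graph_n_nodes(graph); i++) {
        struct ggml_tensor * node = ggml_graph_node(graph, i);
        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, node);
        printf("%-40s -> %s\n", node->name, backend ? ggml_backend_name(backend) : "(unassigned)");
    }
}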

ggml/src/ggml-backend.cpp

Lines changed: 3 additions & 1 deletion
@@ -902,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
 }

 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     // reset splits
     sched->n_splits = 0;
     sched->n_graph_inputs = 0;

@@ -1687,6 +1687,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);

+    ggml_backend_sched_reset(sched);
+
     ggml_backend_sched_synchronize(sched);

     ggml_backend_sched_split_graph(sched, measure_graph);
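
For context on the second hunk: ggml_backend_sched_reserve() now resets the scheduler itself before splitting the measure graph. A minimal sketch of the usual reserve-then-compute cycle under that change, assuming `sched`, a worst-case `gf_measure`, and an actual `gf` built by the caller (the helper name is hypothetical):

#include "ggml.h"
#include "ggml-backend.h"

// Sketch: reserve worst-case compute buffers once, then compute a real graph.
static bool reserve_and_compute(ggml_backend_sched_t sched,
                                struct ggml_cgraph * gf_measure,
                                struct ggml_cgraph * gf) {
    // resets the scheduler, splits the measure graph and allocates worst-case buffers
    if (!ggml_backend_sched_reserve(sched, gf_measure)) {
        return false;
    }
    // clear the measure assignments before scheduling the real graph
    ggml_backend_sched_reset(sched);
    return ggml_backend_sched_graph_compute(sched, gf) == GGML_STATUS_SUCCESS;
}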

src/llama-context.cpp

Lines changed: 57 additions & 56 deletions
@@ -270,19 +270,7 @@ llama_context::llama_context(
         }
     }

-    // resolve automatic Flash Attention use and reserve worst-case graph
     if (!hparams.vocab_only) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
-        int n_splits_pp = -1;
-        int n_nodes_pp  = -1;
-
-        int n_splits_tg = -1;
-        int n_nodes_tg  = -1;
-
         llama_memory_context_ptr mctx;
         if (memory) {
             LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);

@@ -293,6 +281,59 @@ llama_context::llama_context(
         }

         cross.v_embd.clear();
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+            bool fa_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                    continue;
+                }
+                ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_fa != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                                   "is assigned to device %s (usually due to missing support)\n",
+                                   __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                    fa_device_mismatch = true;
+                    break;
+                }
+            }
+            if (fa_device_mismatch) {
+                cparams.flash_attn = false;
+                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                if (ggml_is_quantized(params.type_v)) {
+                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                }
+            } else {
+                cparams.flash_attn = true;
+                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+            }
+        }
+
+        // reserve worst-case graph
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+        int n_splits_pp = -1;
+        int n_nodes_pp  = -1;
+
+        int n_splits_tg = -1;
+        int n_nodes_tg  = -1;

         // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {

@@ -301,48 +342,6 @@ llama_context::llama_context(
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }

-            if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-                ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-                const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
-                bool fa_device_mismatch = false;
-                for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-                    ggml_tensor * n = ggml_graph_node(gf, i);
-                    if (n->op != GGML_OP_FLASH_ATTN_EXT) {
-                        continue;
-                    }
-                    ggml_backend_dev_t device_fa = ggml_backend_get_device(
-                        ggml_backend_sched_get_tensor_backend(sched.get(), n));
-
-                    // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
-                    GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
-                    const int il = std::stoi(n->name + prefix_len);
-                    ggml_backend_dev_t device_kv = model.dev_layer(il);
-                    if (device_fa != device_kv) {
-                        LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
-                                       "is assigned to device %s (usually due to missing support)\n",
-                                       __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
-                        // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
-                        fa_device_mismatch = true;
-                        break;
-                    }
-                }
-                if (fa_device_mismatch) {
-                    cparams.flash_attn = false;
-                    LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
-                    if (ggml_is_quantized(params.type_v)) {
-                        throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
-                    }
-                    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-                    if (!gf) {
-                        throw std::runtime_error("failed to allocate compute pp buffers");
-                    }
-                } else {
-                    cparams.flash_attn = true;
-                    LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
-                }
-            }
-
             n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_pp  = ggml_graph_n_nodes(gf);
         }

@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }

-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);

     if (n_tokens % n_seqs != 0) {

@@ -1401,7 +1400,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;

     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+    if (split_only) {
+        ggml_backend_sched_split_graph(sched.get(), gf);
+    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         return nullptr;
     }
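
One detail worth noting: the check recovers the layer index by parsing it out of the Flash Attention tensor name, which is expected to look like "<LLAMA_TENSOR_NAME_FATTN>-<layer>" (the macro's actual value is an internal detail not shown in this diff). A standalone sketch of that parsing, using a hypothetical "fattn" prefix as a stand-in:

#include <cstdio>
#include <cstring>
#include <string>

int main() {
    const char * prefix = "fattn";      // hypothetical stand-in for LLAMA_TENSOR_NAME_FATTN
    const char * name   = "fattn-17";   // example node name of the form "<prefix>-<layer>"

    const size_t prefix_len = strlen(prefix) + 1;     // prefix plus the '-' separator
    if (strncmp(name, "fattn-", prefix_len) != 0) {
        return 1;                                     // not a Flash Attention tensor
    }
    const int il = std::stoi(name + prefix_len);      // layer index encoded after the prefix
    printf("layer index: %d\n", il);                  // prints 17
    return 0;
}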

src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ struct llama_context {
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);

     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);

 private:
     llm_graph_params graph_params(
