Skip to content

Commit 86f0cea

Browse files
llama: use max. GPU layers by default, auto -fa
1 parent da54f9f commit 86f0cea

File tree

14 files changed

+149
-55
lines changed

14 files changed

+149
-55
lines changed

common/arg.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,10 +1545,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
15451545
}
15461546
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
15471547
add_opt(common_arg(
1548-
{"-fa", "--flash-attn"},
1549-
string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
1550-
[](common_params & params) {
1551-
params.flash_attn = true;
1548+
{"-fa", "--flash-attn"}, "FA",
1549+
string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
1550+
[](common_params & params, const std::string & value) {
1551+
if (value == "on" || value == "enabled") {
1552+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
1553+
} else if (value == "off" || value == "disabled") {
1554+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
1555+
} else if (value == "auto") {
1556+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
1557+
} else {
1558+
throw std::runtime_error(string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
1559+
}
15521560
}
15531561
).set_env("LLAMA_ARG_FLASH_ATTN"));
15541562
add_opt(common_arg(
@@ -3459,8 +3467,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34593467
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
34603468
params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
34613469
params.port = 8012;
3462-
params.n_gpu_layers = 99;
3463-
params.flash_attn = true;
3470+
params.n_gpu_layers = 999;
3471+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
34643472
params.n_ubatch = 1024;
34653473
params.n_batch = 1024;
34663474
params.n_ctx = 0;
@@ -3475,8 +3483,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34753483
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
34763484
params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
34773485
params.port = 8012;
3478-
params.n_gpu_layers = 99;
3479-
params.flash_attn = true;
3486+
params.n_gpu_layers = 999;
3487+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
34803488
params.n_ubatch = 1024;
34813489
params.n_batch = 1024;
34823490
params.n_ctx = 0;
@@ -3491,8 +3499,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
34913499
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
34923500
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
34933501
params.port = 8012;
3494-
params.n_gpu_layers = 99;
3495-
params.flash_attn = true;
3502+
params.n_gpu_layers = 999;
34963503
params.n_ubatch = 1024;
34973504
params.n_batch = 1024;
34983505
params.n_ctx = 0;
@@ -3508,10 +3515,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
35083515
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
35093516
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
35103517
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
3511-
params.speculative.n_gpu_layers = 99;
3518+
params.speculative.n_gpu_layers = 999;
35123519
params.port = 8012;
3513-
params.n_gpu_layers = 99;
3514-
params.flash_attn = true;
3520+
params.n_gpu_layers = 999;
3521+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
35153522
params.n_ubatch = 1024;
35163523
params.n_batch = 1024;
35173524
params.n_ctx = 0;
@@ -3527,10 +3534,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
35273534
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
35283535
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
35293536
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
3530-
params.speculative.n_gpu_layers = 99;
3537+
params.speculative.n_gpu_layers = 999;
35313538
params.port = 8012;
3532-
params.n_gpu_layers = 99;
3533-
params.flash_attn = true;
3539+
params.n_gpu_layers = 999;
3540+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
35343541
params.n_ubatch = 1024;
35353542
params.n_batch = 1024;
35363543
params.n_ctx = 0;
@@ -3546,7 +3553,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
35463553
params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
35473554
params.port = 8012;
35483555
params.n_gpu_layers = 99;
3549-
params.flash_attn = true;
3556+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
35503557
params.n_ubatch = 1024;
35513558
params.n_batch = 1024;
35523559
params.n_ctx = 0;

common/common.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,8 @@ struct common_init_result common_init_from_params(common_params & params) {
901901

902902
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
903903
if (model == NULL) {
904-
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
904+
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
905+
__func__, params.model.path.c_str());
905906
return iparams;
906907
}
907908

@@ -911,7 +912,8 @@ struct common_init_result common_init_from_params(common_params & params) {
911912

912913
llama_context * lctx = llama_init_from_model(model, cparams);
913914
if (lctx == NULL) {
914-
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
915+
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
916+
__func__, params.model.path.c_str());
915917
llama_model_free(model);
916918
return iparams;
917919
}
@@ -1152,10 +1154,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
11521154
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
11531155
cparams.pooling_type = params.pooling_type;
11541156
cparams.attention_type = params.attention_type;
1157+
cparams.flash_attn_type = params.flash_attn_type;
11551158
cparams.cb_eval = params.cb_eval;
11561159
cparams.cb_eval_user_data = params.cb_eval_user_data;
11571160
cparams.offload_kqv = !params.no_kv_offload;
1158-
cparams.flash_attn = params.flash_attn;
11591161
cparams.no_perf = params.no_perf;
11601162
cparams.op_offload = !params.no_op_offload;
11611163
cparams.swa_full = params.swa_full;

common/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ struct common_params {
309309
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
310310
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
311311
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
312+
enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
312313

313314
struct common_params_sampling sampling;
314315
struct common_params_speculative speculative;
@@ -372,7 +373,6 @@ struct common_params {
372373
bool multiline_input = false; // reverse the usage of `\`
373374
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
374375
bool cont_batching = true; // insert new sequences for decoding on-the-fly
375-
bool flash_attn = false; // flash attention
376376
bool no_perf = false; // disable performance metrics
377377
bool ctx_shift = false; // context shift on infinite text generation
378378
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)

examples/diffusion/diffusion-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ int main(int argc, char ** argv) {
564564
ctx_params.n_ctx = params.n_ctx;
565565
ctx_params.n_batch = params.n_batch;
566566
ctx_params.n_ubatch = params.n_ubatch;
567-
ctx_params.flash_attn = params.flash_attn;
567+
ctx_params.flash_attn_type = params.flash_attn_type;
568568
ctx_params.no_perf = params.no_perf;
569569
ctx_params.type_k = params.cache_type_k;
570570
ctx_params.type_v = params.cache_type_v;

ggml/src/ggml-backend.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor *
346346
}
347347

348348
ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
349+
GGML_ASSERT(backend);
349350
return backend->device;
350351
}
351352

include/llama.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,14 @@ extern "C" {
179179
LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1,
180180
};
181181

182+
enum llama_flash_attn_type {
183+
LLAMA_FLASH_ATTN_TYPE_AUTO = -1,
184+
LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
185+
LLAMA_FLASH_ATTN_TYPE_ENABLED = 1,
186+
};
187+
188+
LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
189+
182190
enum llama_split_mode {
183191
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
184192
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -303,6 +311,7 @@ extern "C" {
303311
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
304312
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
305313
enum llama_attention_type attention_type; // attention type to use for embeddings
314+
enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention
306315

307316
// ref: https://github.com/ggml-org/llama.cpp/pull/2054
308317
float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -329,7 +338,6 @@ extern "C" {
329338
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
330339
bool embeddings; // if true, extract embeddings (together with logits)
331340
bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
332-
bool flash_attn; // use flash attention [EXPERIMENTAL]
333341
bool no_perf; // measure performance timings
334342
bool op_offload; // offload host tensor operations to device
335343
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)

scripts/server-bench.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,6 @@ def benchmark(
151151
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
152152
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
153153
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
154-
if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
155-
logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
156-
os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
157-
if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
158-
logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
159-
os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
160154

161155
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
162156
prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)

src/llama-context.cpp

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ llama_context::llama_context(
4141
cparams.yarn_beta_slow = params.yarn_beta_slow;
4242
cparams.embeddings = params.embeddings;
4343
cparams.offload_kqv = params.offload_kqv;
44-
cparams.flash_attn = params.flash_attn;
4544
cparams.no_perf = params.no_perf;
4645
cparams.pooling_type = params.pooling_type;
4746
cparams.warmup = false;
@@ -86,6 +85,8 @@ llama_context::llama_context(
8685
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
8786
}
8887

88+
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
89+
8990
// with causal attention, the batch size is limited by the context size
9091
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
9192

@@ -129,7 +130,7 @@ llama_context::llama_context(
129130
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
130131
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
131132
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
132-
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
133+
LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type));
133134
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
134135
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
135136
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -279,7 +280,7 @@ llama_context::llama_context(
279280
}
280281
}
281282

282-
// reserve worst-case graph
283+
// resolve automatic Flash Attention use and reserve worst-case graph
283284
if (!hparams.vocab_only) {
284285
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
285286
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
@@ -310,6 +311,42 @@ llama_context::llama_context(
310311
throw std::runtime_error("failed to allocate compute pp buffers");
311312
}
312313

314+
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
315+
ggml_backend_sched_alloc_graph(sched.get(), gf);
316+
317+
bool fa_device_mismatch = false;
318+
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
319+
ggml_tensor * n = ggml_graph_node(gf, i);
320+
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
321+
continue;
322+
}
323+
ggml_backend_dev_t device_fa = ggml_backend_get_device(
324+
ggml_backend_sched_get_tensor_backend(sched.get(), n));
325+
326+
GGML_ASSERT(strncmp(n->name, "fattn-", 6) == 0);
327+
const int il = std::stoi(n->name + 6);
328+
ggml_backend_dev_t device_kv = model.dev_layer(il);
329+
if (device_fa != device_kv) {
330+
fa_device_mismatch = true;
331+
break;
332+
}
333+
}
334+
if (fa_device_mismatch) {
335+
cparams.flash_attn = false;
336+
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to disabled\n", __func__);
337+
if (ggml_is_quantized(params.type_v)) {
338+
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
339+
}
340+
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
341+
if (!gf) {
342+
throw std::runtime_error("failed to allocate compute pp buffers");
343+
}
344+
} else {
345+
cparams.flash_attn = true;
346+
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
347+
}
348+
}
349+
313350
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
314351
n_nodes_pp = ggml_graph_n_nodes(gf);
315352
}
@@ -2230,6 +2267,7 @@ llama_context_params llama_context_default_params() {
22302267
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
22312268
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
22322269
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
2270+
/*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
22332271
/*.rope_freq_base =*/ 0.0f,
22342272
/*.rope_freq_scale =*/ 0.0f,
22352273
/*.yarn_ext_factor =*/ -1.0f,
@@ -2246,7 +2284,6 @@ llama_context_params llama_context_default_params() {
22462284
/*.abort_callback_data =*/ nullptr,
22472285
/*.embeddings =*/ false,
22482286
/*.offload_kqv =*/ true,
2249-
/*.flash_attn =*/ false,
22502287
/*.no_perf =*/ true,
22512288
/*.op_offload =*/ true,
22522289
/*.swa_full =*/ true,
@@ -2274,12 +2311,30 @@ llama_context * llama_init_from_model(
22742311
return nullptr;
22752312
}
22762313

2277-
if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
2314+
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
22782315
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
2279-
params.flash_attn = false;
2316+
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
2317+
}
2318+
2319+
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
2320+
const uint32_t blck_size = ggml_blck_size(params.type_k);
2321+
if (model->hparams.n_embd_head_k % blck_size != 0) {
2322+
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
2323+
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
2324+
return nullptr;
2325+
}
2326+
}
2327+
2328+
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
2329+
const uint32_t blck_size = ggml_blck_size(params.type_v);
2330+
if (model->hparams.n_embd_head_v % blck_size != 0) {
2331+
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
2332+
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
2333+
return nullptr;
2334+
}
22802335
}
22812336

2282-
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
2337+
if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
22832338
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
22842339
return nullptr;
22852340
}

0 commit comments

Comments
 (0)