Conversation

JohannesGaessler
Collaborator

This PR updates the llama.cpp defaults to use FlashAttention and the maximum number of GPU layers. FlashAttention is, I think, by now mature enough that it is the better choice for most combinations of models and hardware. Both 0 and max. GPU layers have downsides, but I think there are more cases where max. GPU layers is the better choice. In particular, someone using llama.cpp for the first time, and therefore most reliant on the defaults, is likely to be testing with a very small model, and in that scenario max. GPU layers is definitely the correct choice.

github-actions bot added the script (Script related) and python (python script changes) labels on Aug 19, 2025
@slaren
Member

slaren commented Aug 19, 2025

I agree with increasing ngl by default, but the error message when model loading fails due to a buffer allocation error should give users a hint about what they need to change.

The reason FA is not already the default for the backends that support it is that the CUDA backend implementation of supports_op is not reliable. You can reproduce this with `llama-cli -hf ggml-org/tiny-llamas -hff stories260K.gguf -fa -ngl 99`. If you can fix that, I can revive #10101.

github-actions bot added the examples and ggml (changes relating to the ggml tensor library for machine learning) labels on Aug 21, 2025
@JohannesGaessler
Collaborator Author

I pushed a version that seems to work for automatically setting FlashAttention (the same for all layers). The way I'm determining whether FA should be used is to check whether or not the FA ggml op is being assigned to the same backend as the previous node in the graph. But I think that is a bad solution. Would it make sense to set a flag for tensors that cannot run on fast backends?

I'm resolving `-fa auto` at the point where the worst-case graphs are reserved, since that is where FA becomes relevant.

For this PR my goal is not to implement toggling FA on a per-layer basis - I'm not convinced that there are many situations where this would make sense.

Comment on lines 313 to 316
if (cparams.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
    bool fa_backend_mismatch = false;
    GGML_ASSERT(ggml_graph_node(gf, 0)->op != GGML_OP_FLASH_ATTN_EXT);
    for (int i = 1; i < ggml_graph_n_nodes(gf); i++) {
Member

I don't think this is a good way to do it. It is very fragile code that makes a lot of assumptions that are not guaranteed anywhere, and it will break easily, and in a way that is very difficult to detect, when changes are made to other parts of the code.

Member

A potentially slightly better way to do it could be (see the sketch below):

  • Extract the layer number from the tensor name
  • Verify whether the device of the backend (ggml_backend_get_device) is the same as the device assigned to the layer's KV cache
  • The device assigned to the layer can be obtained from model.dev_layer(il) if offload_kqv is set, CPU otherwise
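
A rough sketch of what this check could amount to (not the PR's final code; it assumes the surrounding llama_context members sched, gf and model, and that the attention node's name ends in the layer index, as the snippet later in this thread parses with std::stoi(n->name + 6)):

for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
    ggml_tensor * n = ggml_graph_node(gf, i);
    if (n->op != GGML_OP_FLASH_ATTN_EXT) {
        continue;
    }
    // backend/device the scheduler assigned to this attention node
    ggml_backend_t     backend_fa = ggml_backend_sched_get_tensor_backend(sched.get(), n);
    ggml_backend_dev_t device_fa  = ggml_backend_get_device(backend_fa);
    // layer index parsed from the tensor name (assumed naming convention)
    const int il = std::stoi(n->name + 6);
    // device assigned to the layer (per the note above, use CPU instead when offload_kqv is disabled)
    ggml_backend_dev_t device_kv = model.dev_layer(il);
    if (device_fa != device_kv) {
        // the FA node would run on a different device than the layer's KV cache
        // -> treat FA as unsupported for this configuration and fall back to disabling it
    }
}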

@JohannesGaessler
Collaborator Author

Just to make sure that this doesn't go unnoticed: on master, ggml_backend_sched_reserve automatically resets the backend scheduler; this PR removes that automatic reset because otherwise the tensor assignments cannot be retrieved.

@slaren
Member

slaren commented Aug 27, 2025

Just to make sure that this doesn't go unnoticed: on master, ggml_backend_sched_reserve automatically resets the backend scheduler; this PR removes that automatic reset because otherwise the tensor assignments cannot be retrieved.

You can use ggml_backend_sched_alloc_graph instead; there is no need to change ggml_backend_sched_reserve.
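
A minimal sketch of this alternative (assuming the usual ggml-backend scheduler API and the surrounding llama_context code; not the PR's actual implementation): allocate the worst-case graph with ggml_backend_sched_alloc_graph and read the per-node assignments back from the scheduler, leaving ggml_backend_sched_reserve itself untouched.

if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
    LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
    return nullptr;
}
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
    ggml_tensor  * node    = ggml_graph_node(gf, i);
    ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), node);
    // inspect which backend e.g. the GGML_OP_FLASH_ATTN_EXT nodes were assigned to
    GGML_UNUSED(backend);
}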

@slaren left a comment (Member)

We also need a message telling users to reduce --n-gpu-layers when loading a model fails; otherwise, people trying to run models bigger than their VRAM will just see an error and assume that they cannot use llama.cpp.

common/common.h Outdated
Comment on lines 312 to 317
#ifdef GGML_USE_WEBGPU
// FIXME the webgpu backend is lacking support for very basic operations, so the test allocation for -fa auto results in an abort
enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; // whether to use Flash Attention
#else
enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
#endif // GGML_USE_WEBGPU
Member

I think it would be better to leave the test failing than to add an exception here.

common/arg.cpp Outdated
Comment on lines 3470 to 3471
params.n_gpu_layers = 999;
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
Member

Since these values are the same as the default now, these lines could be removed entirely.

}

ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
    GGML_ASSERT(backend);
Member

I don't mind having asserts against null pointers here, but it needs to be consistent, not just in one isolated function.

@JohannesGaessler
Collaborator Author

We also need a message telling users to reduce --n-gpu-layers when loading a model fails; otherwise, people trying to run models bigger than their VRAM will just see an error and assume that they cannot use llama.cpp.

I added the messages in common.cpp because only in that context does a reference to --n-gpu-layers make sense. Or do you mean that the C API should also explicitly mention the CPU+GPU hybrid functionality?

@slaren
Member

slaren commented Aug 29, 2025

I added the messages in common.cpp

My bad, I missed that. Looks good.

@JohannesGaessler
Collaborator Author

For ggml-backend.cpp I added asserts for pointers that are accessed unconditionally and would otherwise result in a segmentation fault.

@JohannesGaessler
Collaborator Author

Supposedly there are issues with context shifting when using FlashAttention: #9646

This would align with the test in test_ctx_shift.py failing if FA is enabled (already happens on master, independently of this PR).

@ggerganov
Member

Supposedly there are issues with context shifting when using FlashAttention: #9646

This would align with the test in test_ctx_shift.py failing if FA is enabled (already happens on master, independently of this PR).

Could you show a failure log or steps to reproduce?

Comment on lines 327 to 329
const int il = std::stoi(n->name + 6);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_fa != device_kv) {
Member

This should also check against the CPU when using no-kv-offload, but that seems to be broken at the moment: attention ops are being run on the GPU even when the KV cache is not offloaded.
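
A hedged sketch of the adjustment being described (names mirror the snippet above; cparams.offload_kqv and the CPU-device lookup are assumptions, not the final code):

// when the KV cache is not offloaded, the attention op is expected to run
// on the CPU device rather than on the layer's GPU device
ggml_backend_dev_t device_kv = cparams.offload_kqv
    ? model.dev_layer(il)
    : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (device_fa != device_kv) {
    // the attention op would not run next to the KV cache -> fall back to disabling FA
}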

# 64 tokens are generated thanks to shifting the context when it gets full
global server
server.enable_ctx_shift = True
server.fa = "off" # FIXME prompt_n assert fails otherwise
Collaborator Author

@ggerganov remove this line or set it to "on", then run the unit test. Alternatively, edit the unit test on master to run with FA.

Member

Is this the assert that you observed too:

        assert res.status_code == 200
>       assert res.body["timings"]["prompt_n"] == 109
E       assert 173 == 109
unit/test_ctx_shift.py:36: AssertionError

FAILED unit/test_ctx_shift.py::test_ctx_shift_enabled - assert 173 == 109

This occurs because we pad the context size to 256 when flash attention is enabled:

uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
    // the FA kernels require padding to avoid extra runtime boundary checks
    return cparams.flash_attn ? 256u : 32u;
}

So in this test, when FA is off the padding is 32 and when FA is on it is 256. This affects the number of tokens truncated from the prompt.

You can fix this on master with the following patch, which makes the test work both with and without FA:

diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py
index 8f51bc301..3edf18727 100644
--- a/tools/server/tests/unit/test_ctx_shift.py
+++ b/tools/server/tests/unit/test_ctx_shift.py
@@ -15,25 +15,27 @@ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deseru
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
-    server.n_ctx = 256
+    server.n_ctx = 512
     server.n_slots = 2
+    server.n_predict = 128
 
 
 def test_ctx_shift_enabled():
     # the prompt is 301 tokens
-    # the slot context is 256/2 = 128 tokens
-    # the prompt is truncated to keep the last 109 tokens
-    # 64 tokens are generated thanks to shifting the context when it gets full
+    # the slot context is 512/2 = 256 tokens
+    # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens
+    # 96 tokens are generated thanks to shifting the context when it gets full
     global server
     server.enable_ctx_shift = True
+    server.fa = True
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "n_predict": 64,
+        "n_predict": 96,
         "prompt": LONG_TEXT,
     })
     assert res.status_code == 200
-    assert res.body["timings"]["prompt_n"] == 109
-    assert res.body["timings"]["predicted_n"] == 64
+    assert res.body["timings"]["prompt_n"] == 173
+    assert res.body["timings"]["predicted_n"] == 96
     assert res.body["truncated"] is True
 
 
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index f55a53947..d9df9bd91 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -160,7 +160,7 @@ class ServerProcess:
             server_args.extend(["-ctk", self.ctk])
         if self.ctv:
             server_args.extend(["-ctv", self.ctv])
-        if self.fa is not None:
+        if self.fa is not None and self.fa is True:
             server_args.append("-fa")
         if self.n_predict:
             server_args.extend(["--n-predict", self.n_predict])

@JohannesGaessler
Collaborator Author

I think the problem has to do with the same graph being passed twice. If I edit the code like this

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ac8453ab7..105b91c73 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1405,6 +1405,10 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         return nullptr;
     }
+    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        return nullptr;
+    }

     return gf;
 }

so that ggml_backend_sched_reserve is called twice for the same graph, I can provoke the same error.

@JohannesGaessler
Collaborator Author

My understanding is that a ggml_cgraph is intended to be allocated exactly once. However, this logic extends to ggml_backend_sched_reserve since it internally calls ggml_backend_sched_split_graph. In that function the source tensors of the passed graph are replaced, so calling e.g. ggml_backend_sched_reserve multiple times with the same graph, or ggml_backend_sched_reserve followed by ggml_backend_sched_alloc_graph, results in an incorrect assignment of graph inputs, which in this case triggered the assert.

Should we add something like a .used flag to ggml_cgraph to more easily detect when user code passes the same graph multiple times?
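
Purely as an illustration of the idea (hypothetical, nothing like this exists in ggml today), such a flag could look roughly like this:

// hypothetical extra member on ggml_cgraph, set the first time the scheduler consumes it
struct ggml_cgraph {
    // ... existing members ...
    bool sched_used; // set by ggml_backend_sched_split_graph()
};

// at the start of ggml_backend_sched_split_graph(), before the graph's sources are rewritten:
// GGML_ASSERT(!graph->sched_used && "the same ggml_cgraph was passed to the scheduler twice");
// graph->sched_used = true;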

@saadsafi

saadsafi commented Aug 31, 2025

I had this error all morning:
ggml-backend.cpp: GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS) failed
I have 2 NVIDIA GPUs. I build from source on Ubuntu (NVIDIA CUDA container) every morning.

I read the comments above and I can confirm the error was resolved by adding -fa off.

@jacekpoplawski
Contributor

I’m not sure which commit changed the behavior, but it looks like it works correctly now.

@jacekpoplawski
Contributor

jacekpoplawski commented Sep 1, 2025

All models work except gemma-3n (both E4B and E2B)

llama_context: pipeline parallelism enabled (n_copies=4)
/home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp:1258: GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS) failed

(works with -fa off or CUDA_VISIBLE_DEVICES=0)

@slaren
Member

slaren commented Sep 1, 2025

All models work except gemma-3n (both E4B and E2B)

llama_context: pipeline parallelism enabled (n_copies=4)
/home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp:1258: GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS) failed

(works with -fa off or CUDA_VISIBLE_DEVICES=0)

Cannot reproduce this with two GPUs.

@jacekpoplawski
Contributor

Cannot reproduce this with two GPUs.

You are right, CUDA_VISIBLE_DEVICES=0,1 also fixes the issue, so 3 GPUs are needed to reproduce it.

I can debug or send more logs if needed

call stack:

#4  0x000071dc27b77f13 in ggml_print_backtrace () from /home/jacek/git/llama.cpp/build_2025.08.31/bin/libggml-base.so
#5  0x000071dc27b780bb in ggml_abort () from /home/jacek/git/llama.cpp/build_2025.08.31/bin/libggml-base.so
#6  0x000071dc27b91126 in ggml_backend_sched_split_graph () from /home/jacek/git/llama.cpp/build_2025.08.31/bin/libggml-base.so
#7  0x000071dc27c9c752 in llama_context::graph_reserve(unsigned int, unsigned int, unsigned int, llama_memory_context_i const*, bool) () from /home/jacek/git/llama.cpp/build_2025.08.31/bin/libllama.so
#8  0x000071dc27c9f9aa in llama_context::llama_context(llama_model const&, llama_context_params) () from /home/jacek/git/llama.cpp/build_2025.08.31/bin/libllama.so
#9  0x000071dc27ca0146 in llama_init_from_model () from /home/jacek/git/llama.cpp/build_2025.08.31/bin/libllama.so
#10 0x0000587795327186 in common_init_from_params(common_params&) ()
#11 0x000058779521085f in server_context::load_model(common_params const&) ()
#12 0x00005877951a560b in main ()

@slaren
Member

slaren commented Sep 1, 2025

@JohannesGaessler I cannot test this easily, but please increase GGML_SCHED_MAX_SPLIT_INPUTS as much as needed to fix this.

@jacekpoplawski
Contributor

I tried this:

#define GGML_MAX_SRC            50

and this:

GGML_LOG_ERROR("n_inputs: %d GGML_SCHED_MAX_SPLIT_INPUT: %d\n", n_inputs, GGML_SCHED_MAX_SPLIT_INPUTS);

on google_gemma-3-27b-it-Q8_0.gguf max value is:

n_inputs: 8 GGML_SCHED_MAX_SPLIT_INPUT: 50

but on google_gemma-3n-E4B-it-Q8_0.gguf:

n_inputs: 26 GGML_SCHED_MAX_SPLIT_INPUT: 50
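
For context on why changing GGML_MAX_SRC moved this limit: the scheduler's split-input limit appears to be defined in terms of GGML_MAX_SRC, which would explain why the printed limit became 50 (approximate excerpt from ggml-backend.cpp; check the current source):

#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
#endif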

@Thireus

Thireus commented Sep 2, 2025

@ubergarm, llama-sweep-bench no longer compiles because the common_params definition no longer includes flash_attn.

In file included from /home/runner/work/llama.cpp/llama.cpp/examples/sweep-bench/sweep-bench.cpp:7:
/home/runner/work/llama.cpp/llama.cpp/examples/sweep-bench/sweep-bench.cpp: In function ‘int main(int, char**)’:
/home/runner/work/llama.cpp/llama.cpp/examples/sweep-bench/sweep-bench.cpp:146:203: error: ‘struct common_params’ has no member named ‘flash_attn’; did you mean ‘flash_attn_type’?
  146 |         LOG_INF("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
      |                                                                                                                                                                                                           ^~~~~~~~~~
/home/runner/work/llama.cpp/llama.cpp/common/./log.h:86:56: note: in definition of macro ‘LOG_TMPL’
   86 |             common_log_add(common_log_main(), (level), __VA_ARGS__); \
      |                                                        ^~~~~~~~~~~
/home/runner/work/llama.cpp/llama.cpp/examples/sweep-bench/sweep-bench.cpp:146:9: note: in expansion of macro ‘LOG_INF’
  146 |         LOG_INF("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
      |         ^~~~~~~
/home/runner/work/llama.cpp/llama.cpp/examples/sweep-bench/sweep-bench.cpp:248:67: error: ‘struct common_params’ has no member named ‘flash_attn’; did you mean ‘flash_attn_type’?
  248 |                 n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch,
      |                                                                   ^~~~~~~~~~
/home/runner/work/llama.cpp/llama.cpp/common/./log.h:86:56: note: in definition of macro ‘LOG_TMPL’
   86 |             common_log_add(common_log_main(), (level), __VA_ARGS__); \
      |                                                        ^~~~~~~~~~~
/home/runner/work/llama.cpp/llama.cpp/examples/sweep-bench/sweep-bench.cpp:245:13: note: in expansion of macro ‘LOG_INF’
  245 |             LOG_INF(
      |             ^~~~~~~

https://github.com/ggml-org/llama.cpp/pull/15434/files#diff-34c932128256ee886b3a8581b5f11a1c38717aaa9d228189f1ce12e823f3207fL375

ubergarm added a commit to ubergarm/llama.cpp that referenced this pull request Sep 2, 2025
Behavior of mainline llama.cpp `-fa` changed: it now *requires* an argument of `on` or `1` to enable flash attention explicitly. This diverges from ik_llama.cpp behavior, where omitting it means disabled; on mainline, omitting it means `auto`, which I believe means "probably enabled".

Details here: ggml-org#15434

This patch just changes all `s/flash_attn/flash_attn_type/g`.
walidbr pushed a commit to walidbr/llama.cpp that referenced this pull request Sep 7, 2025
* llama: use max. GPU layers by default, auto -fa

* ggml-backend: abort instead of segfault
AndrewMobbs added a commit to AndrewMobbs/llauncher that referenced this pull request Sep 23, 2025
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Oct 6, 2025
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Oct 6, 2025
@ubergarm

lol so sorry i ever put a link or usernames in my commit message, i'll try to get rid of that spam... 💀
