@@ -41,7 +41,6 @@ llama_context::llama_context(
4141 cparams.yarn_beta_slow = params.yarn_beta_slow ;
4242 cparams.embeddings = params.embeddings ;
4343 cparams.offload_kqv = params.offload_kqv ;
44- cparams.flash_attn = params.flash_attn ;
4544 cparams.no_perf = params.no_perf ;
4645 cparams.pooling_type = params.pooling_type ;
4746 cparams.warmup = false ;
@@ -86,6 +85,8 @@ llama_context::llama_context(
8685 cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
8786 }
8887
88+ cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
89+
8990 // with causal attention, the batch size is limited by the context size
9091 cparams.n_batch = cparams.causal_attn ? std::min (cparams.n_ctx , params.n_batch ) : params.n_batch ;
9192
@@ -129,7 +130,7 @@ llama_context::llama_context(
129130 LLAMA_LOG_INFO (" %s: n_batch = %u\n " , __func__, cparams.n_batch );
130131 LLAMA_LOG_INFO (" %s: n_ubatch = %u\n " , __func__, cparams.n_ubatch );
131132 LLAMA_LOG_INFO (" %s: causal_attn = %d\n " , __func__, cparams.causal_attn );
132- LLAMA_LOG_INFO (" %s: flash_attn = %d \n " , __func__, cparams. flash_attn );
133+ LLAMA_LOG_INFO (" %s: flash_attn = %s \n " , __func__, llama_flash_attn_type_name (params. flash_attn_type ) );
133134 LLAMA_LOG_INFO (" %s: kv_unified = %s\n " , __func__, cparams.kv_unified ? " true" : " false" );
134135 LLAMA_LOG_INFO (" %s: freq_base = %.1f\n " , __func__, cparams.rope_freq_base );
135136 LLAMA_LOG_INFO (" %s: freq_scale = %g\n " , __func__, cparams.rope_freq_scale );
@@ -279,7 +280,7 @@ llama_context::llama_context(
279280 }
280281 }
281282
282- // reserve worst-case graph
283+ // resolve automatic Flash Attention use and reserve worst-case graph
283284 if (!hparams.vocab_only ) {
284285 const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max ;
285286 const uint32_t n_tokens = std::min (cparams.n_ctx , cparams.n_ubatch );
@@ -310,6 +311,42 @@ llama_context::llama_context(
310311 throw std::runtime_error (" failed to allocate compute pp buffers" );
311312 }
312313
314+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
315+ ggml_backend_sched_alloc_graph (sched.get (), gf);
316+
317+ bool fa_device_mismatch = false ;
318+ for (int i = 0 ; i < ggml_graph_n_nodes (gf); i++) {
319+ ggml_tensor * n = ggml_graph_node (gf, i);
320+ if (n->op != GGML_OP_FLASH_ATTN_EXT) {
321+ continue ;
322+ }
323+ ggml_backend_dev_t device_fa = ggml_backend_get_device (
324+ ggml_backend_sched_get_tensor_backend (sched.get (), n));
325+
326+ GGML_ASSERT (strncmp (n->name , "fattn-" , 6 ) == 0 );
327+ const int il = std::stoi (n->name + 6 );
328+ ggml_backend_dev_t device_kv = model.dev_layer (il);
329+ if (device_fa != device_kv) {
330+ fa_device_mismatch = true ;
331+ break ;
332+ }
333+ }
334+ if (fa_device_mismatch) {
335+ cparams.flash_attn = false ;
336+ LLAMA_LOG_INFO (" %s: Flash Attention was auto, set to disabled\n " , __func__);
337+ if (ggml_is_quantized (params.type_v )) {
338+ throw std::runtime_error (" quantized V cache was requested, but this requires Flash Attention" );
339+ }
340+ auto * gf = graph_reserve (n_tokens, n_seqs, n_tokens, mctx.get ());
341+ if (!gf) {
342+ throw std::runtime_error (" failed to allocate compute pp buffers" );
343+ }
344+ } else {
345+ cparams.flash_attn = true ;
346+ LLAMA_LOG_INFO (" %s: Flash Attention was auto, set to enabled\n " , __func__);
347+ }
348+ }
349+
313350 n_splits_pp = ggml_backend_sched_get_n_splits (sched.get ());
314351 n_nodes_pp = ggml_graph_n_nodes (gf);
315352 }
@@ -2230,6 +2267,7 @@ llama_context_params llama_context_default_params() {
22302267 /* .rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
22312268 /* .pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
22322269 /* .attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
2270+ /* .flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
22332271 /* .rope_freq_base =*/ 0.0f ,
22342272 /* .rope_freq_scale =*/ 0.0f ,
22352273 /* .yarn_ext_factor =*/ -1.0f ,
@@ -2246,7 +2284,6 @@ llama_context_params llama_context_default_params() {
22462284 /* .abort_callback_data =*/ nullptr ,
22472285 /* .embeddings =*/ false ,
22482286 /* .offload_kqv =*/ true ,
2249- /* .flash_attn =*/ false ,
22502287 /* .no_perf =*/ true ,
22512288 /* .op_offload =*/ true ,
22522289 /* .swa_full =*/ true ,
@@ -2274,12 +2311,30 @@ llama_context * llama_init_from_model(
22742311 return nullptr ;
22752312 }
22762313
2277- if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
2314+ if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
22782315 LLAMA_LOG_WARN (" %s: flash_attn is not compatible with Grok - forcing off\n " , __func__);
2279- params.flash_attn = false ;
2316+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
2317+ }
2318+
2319+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized (params.type_k )) {
2320+ const uint32_t blck_size = ggml_blck_size (params.type_k );
2321+ if (model->hparams .n_embd_head_k % blck_size != 0 ) {
2322+ LLAMA_LOG_ERROR (" %s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n " ,
2323+ __func__, ggml_type_name (params.type_k ), blck_size, model->hparams .n_embd_head_k );
2324+ return nullptr ;
2325+ }
2326+ }
2327+
2328+ if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized (params.type_v )) {
2329+ const uint32_t blck_size = ggml_blck_size (params.type_v );
2330+ if (model->hparams .n_embd_head_v % blck_size != 0 ) {
2331+ LLAMA_LOG_ERROR (" %s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n " ,
2332+ __func__, ggml_type_name (params.type_v ), blck_size, model->hparams .n_embd_head_v );
2333+ return nullptr ;
2334+ }
22802335 }
22812336
2282- if (ggml_is_quantized (params.type_v ) && ! params.flash_attn ) {
2337+ if (ggml_is_quantized (params.type_v ) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED ) {
22832338 LLAMA_LOG_ERROR (" %s: V cache quantization requires flash_attn\n " , __func__);
22842339 return nullptr ;
22852340 }
0 commit comments