@@ -270,19 +270,7 @@ llama_context::llama_context(
         }
     }
 
-    // resolve automatic Flash Attention use and reserve worst-case graph
     if (!hparams.vocab_only) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
-        int n_splits_pp = -1;
-        int n_nodes_pp = -1;
-
-        int n_splits_tg = -1;
-        int n_nodes_tg = -1;
-
         llama_memory_context_ptr mctx;
         if (memory) {
             LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
@@ -293,6 +281,59 @@ llama_context::llama_context(
         }
 
         cross.v_embd.clear();
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+            bool fa_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                    continue;
+                }
+                ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_fa != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                        "is assigned to device %s (usually due to missing support)\n",
+                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                    fa_device_mismatch = true;
+                    break;
+                }
+            }
+            if (fa_device_mismatch) {
+                cparams.flash_attn = false;
+                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                if (ggml_is_quantized(params.type_v)) {
+                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                }
+            } else {
+                cparams.flash_attn = true;
+                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+            }
+        }
+
+        // reserve worst-case graph
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+        int n_splits_pp = -1;
+        int n_nodes_pp = -1;
+
+        int n_splits_tg = -1;
+        int n_nodes_tg = -1;
 
         // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {
@@ -301,48 +342,6 @@ llama_context::llama_context(
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
-            if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-                ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-                const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
-                bool fa_device_mismatch = false;
-                for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-                    ggml_tensor * n = ggml_graph_node(gf, i);
-                    if (n->op != GGML_OP_FLASH_ATTN_EXT) {
-                        continue;
-                    }
-                    ggml_backend_dev_t device_fa = ggml_backend_get_device(
-                        ggml_backend_sched_get_tensor_backend(sched.get(), n));
-
-                    // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
-                    GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
-                    const int il = std::stoi(n->name + prefix_len);
-                    ggml_backend_dev_t device_kv = model.dev_layer(il);
-                    if (device_fa != device_kv) {
-                        LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
-                            "is assigned to device %s (usually due to missing support)\n",
-                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
-                        // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
-                        fa_device_mismatch = true;
-                        break;
-                    }
-                }
-                if (fa_device_mismatch) {
-                    cparams.flash_attn = false;
-                    LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
-                    if (ggml_is_quantized(params.type_v)) {
-                        throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
-                    }
-                    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-                    if (!gf) {
-                        throw std::runtime_error("failed to allocate compute pp buffers");
-                    }
-                } else {
-                    cparams.flash_attn = true;
-                    LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
-                }
-            }
-
             n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_pp = ggml_graph_n_nodes(gf);
         }
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1401,7 +1400,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+    if (split_only) {
+        ggml_backend_sched_split_graph(sched.get(), gf);
+    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        return nullptr;
     }
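
Note: the Flash Attention auto-check added above identifies which layer an FA node belongs to by parsing the node name, which is expected to be LLAMA_TENSOR_NAME_FATTN followed by "-" and the layer index; that is what the GGML_ASSERT and std::stoi calls do. Below is a minimal standalone sketch of just that parsing step, for readers following along outside the codebase; the macro value used here is an assumption for illustration and is not taken from this commit.

// Illustrative sketch of the tensor-name parsing used by the FA device check.
// The macro value below is assumed for the example only.
#include <cassert>
#include <cstdio>
#include <cstring>
#include <string>

#define LLAMA_TENSOR_NAME_FATTN "__fattn__" // assumed value, for illustration

int main() {
    const char * name = LLAMA_TENSOR_NAME_FATTN "-17"; // e.g. the FA node of layer 17

    const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; // name plus '-'
    assert(strncmp(name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);

    const int il = std::stoi(name + prefix_len); // layer index follows the prefix
    printf("layer index: %d\n", il);             // prints 17
    return 0;
}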