@@ -270,19 +270,7 @@ llama_context::llama_context(
         }
     }
 
-    // resolve automatic Flash Attention use and reserve worst-case graph
     if (!hparams.vocab_only) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
-        int n_splits_pp = -1;
-        int n_nodes_pp = -1;
-
-        int n_splits_tg = -1;
-        int n_nodes_tg = -1;
-
         llama_memory_context_ptr mctx;
         if (memory) {
             LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
@@ -293,6 +281,59 @@ llama_context::llama_context(
         }
 
         cross.v_embd.clear();
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+            bool fa_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                    continue;
+                }
+                ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_fa != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                        "is assigned to device %s (usually due to missing support)\n",
+                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                    fa_device_mismatch = true;
+                    break;
+                }
+            }
+            if (fa_device_mismatch) {
+                cparams.flash_attn = false;
+                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                if (ggml_is_quantized(params.type_v)) {
+                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                }
+            } else {
+                cparams.flash_attn = true;
+                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+            }
+        }
+
+        // reserve worst-case graph
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+        int n_splits_pp = -1;
+        int n_nodes_pp = -1;
+
+        int n_splits_tg = -1;
+        int n_nodes_tg = -1;
 
         // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {
@@ -301,48 +342,6 @@ llama_context::llama_context(
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
-            if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-                ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-                const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
-                bool fa_device_mismatch = false;
-                for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-                    ggml_tensor * n = ggml_graph_node(gf, i);
-                    if (n->op != GGML_OP_FLASH_ATTN_EXT) {
-                        continue;
-                    }
-                    ggml_backend_dev_t device_fa = ggml_backend_get_device(
-                        ggml_backend_sched_get_tensor_backend(sched.get(), n));
-
-                    // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
-                    GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
-                    const int il = std::stoi(n->name + prefix_len);
-                    ggml_backend_dev_t device_kv = model.dev_layer(il);
-                    if (device_fa != device_kv) {
-                        LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
-                            "is assigned to device %s (usually due to missing support)\n",
-                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
-                        // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
-                        fa_device_mismatch = true;
-                        break;
-                    }
-                }
-                if (fa_device_mismatch) {
-                    cparams.flash_attn = false;
-                    LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
-                    if (ggml_is_quantized(params.type_v)) {
-                        throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
-                    }
-                    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-                    if (!gf) {
-                        throw std::runtime_error("failed to allocate compute pp buffers");
-                    }
-                } else {
-                    cparams.flash_attn = true;
-                    LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
-                }
-            }
-
             n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_pp = ggml_graph_n_nodes(gf);
         }
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1401,7 +1400,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+    if (split_only) {
+        ggml_backend_sched_split_graph(sched.get(), gf);
+    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         return nullptr;
     }
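
Note: the matching declaration in llama-context.h is outside this excerpt. A minimal sketch of how it would presumably look, assuming the new parameter is defaulted so that the existing four-argument call sites (such as the worst-case pp/tg reservations) keep compiling unchanged:

    // sketch only, not part of this diff; the "= false" default is an assumption
    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs,
                                const llama_memory_context_i * mctx, bool split_only = false);

With split_only = true, graph_reserve only runs ggml_backend_sched_split_graph so that per-tensor backend assignments can be inspected for the Flash Attention check, instead of allocating compute buffers through ggml_backend_sched_reserve.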