Commit 2749662

llama : fix fattn reserve call n_seqs parameter (ggml-org#15699)

ggml-ci

1 parent: 9777032

1 file changed: src/llama-context.cpp (7 additions, 6 deletions)
@@ -281,9 +281,15 @@ llama_context::llama_context(
     }
 
     cross.v_embd.clear();
+
+    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
     // resolve automatic Flash Attention use
     if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-        auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+        auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
         if (!gf) {
             throw std::runtime_error("failed to split graph for Flash Attention check");
         }
@@ -324,11 +330,6 @@ llama_context::llama_context(
     }
 
     // reserve worst-case graph
-    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-    LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
     int n_splits_pp = -1;
     int n_nodes_pp = -1;
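What the change does, reading from the diff alone: the worst-case `n_seqs`/`n_tokens` computation (and its debug log) is hoisted above the automatic Flash Attention check, so the trial `graph_reserve()` call can pass the actual worst-case `n_seqs` and `n_outputs` rather than the hard-coded `1` and `0` it used before. The sketch below is a minimal, standalone illustration of how those worst-case values are derived; `cparams_t`, its sample values, and `main()` are hypothetical stand-ins for the real `llama_cparams`, not the llama.cpp API.

```cpp
// Minimal standalone sketch (hypothetical types, not the llama.cpp API)
// of the worst-case parameter derivation shown in the diff.
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Stand-in for the subset of llama_cparams fields the diff touches.
struct cparams_t {
    bool     kv_unified; // one shared KV-cache stream vs. one per sequence
    uint32_t n_seq_max;  // maximum number of parallel sequences
    uint32_t n_ctx;      // context size in tokens
    uint32_t n_ubatch;   // physical micro-batch size
};

int main() {
    const cparams_t cparams = {/*kv_unified=*/false, /*n_seq_max=*/4,
                               /*n_ctx=*/4096, /*n_ubatch=*/512};

    // With a unified KV cache every sequence shares one stream, so the
    // worst case is a single stream; otherwise each sequence needs its own.
    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;

    // A micro-batch can never carry more tokens than the context allows.
    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

    // Prints: worst-case: n_tokens = 512, n_seqs = 4
    std::printf("worst-case: n_tokens = %u, n_seqs = %u\n",
                (unsigned) n_tokens, (unsigned) n_seqs);
    return 0;
}
```

The practical effect of the fix, as far as the diff shows, is that the Flash Attention auto-detection now reserves a graph sized for the same worst case as the later "reserve worst-case graph" step, instead of a trivial 1-sequence, 0-output graph.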
