Commit f9971ef

llama : dedup reserve code
1 parent 972f91c commit f9971ef

File tree

1 file changed: +2 -48 lines changed


src/llama.cpp

Lines changed: 2 additions & 48 deletions
@@ -7629,30 +7629,6 @@ static int llama_decode_impl(
         return -3;
     }
 
-    // reserve a worst case graph if needed
-    // TODO: extract to a function
-    if (lctx.need_reserve) {
-        const auto & cparams = lctx.cparams;
-        const auto & model   = lctx.model;
-
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched.get());
-        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        }
-
-        lctx.need_reserve = false;
-    }
-
     ggml_backend_sched_reset(lctx.sched.get());
     ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

@@ -7889,30 +7865,8 @@ static int llama_encode_impl(
 
     //batch_manager->prepare(ubatch);
 
-    // reserve a worst case graph if needed
-    // TODO: extract to a function
-    if (lctx.need_reserve) {
-        // TODO: extract to a function
-        const auto & cparams = lctx.cparams;
-        const auto & model   = lctx.model;
-
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched.get());
-        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        }
-
-        lctx.need_reserve = false;
-    }
+    // TODO: do reserve
+    GGML_ASSERT(lctx.need_reserve == false);
 
     ggml_backend_sched_reset(lctx.sched.get());
     ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
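
Note: the blocks removed from llama_decode_impl and llama_encode_impl were essentially identical, and their own "TODO: extract to a function" comments hint at how such duplication could be factored out. The sketch below is a hypothetical reconstruction assembled only from the deleted lines; the helper name llama_reserve_worst_case_graph is made up for illustration and is not part of this commit.

// Minimal sketch, not code from this commit: one possible helper that both
// call sites could invoke, built from the removed lines above. The name
// llama_reserve_worst_case_graph is hypothetical.
static void llama_reserve_worst_case_graph(llama_context & lctx) {
    if (!lctx.need_reserve) {
        return;
    }

    const auto & cparams = lctx.cparams;
    const auto & model   = lctx.model;

    // build worst-case graph
    uint32_t n_seqs   = 1; // TODO: worst-case number of sequences
    uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

    // the token is not actually used by llama_build_graph, but it is required
    // to choose between the token and embedding input graphs
    llama_token token   = model.vocab.token_bos();
    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);

    // initialize the scheduler with the worst-case graph
    ggml_backend_sched_reset(lctx.sched.get());
    if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
    }

    lctx.need_reserve = false;
}

Each call site would then shrink to a single llama_reserve_worst_case_graph(lctx); call ahead of the ggml_backend_sched_reset / ggml_backend_sched_set_eval_callback sequence. That is only one way the duplication could be removed; the commit itself simply drops the block in the decode path and, in the encode path, asserts that no reservation is pending.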
