Skip to content

Commit 0d9c3d4

Browse files
committed
llama : reuse compute graphs
ggml-ci
1 parent bac8bed commit 0d9c3d4

File tree

14 files changed

+397
-182
lines changed

14 files changed

+397
-182
lines changed

common/arg.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
14641464
params.swa_full = true;
14651465
}
14661466
).set_env("LLAMA_ARG_SWA_FULL"));
1467+
add_opt(common_arg(
1468+
{"--graph-reuse", "-gr"},
1469+
string_format("reuse previous compute graphs when possible (default: %s)"
1470+
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14482)", params.graph_reuse ? "true" : "false"),
1471+
[](common_params & params) {
1472+
params.graph_reuse = true;
1473+
}
1474+
).set_env("LLAMA_ARG_GRAPH_REUSE"));
14671475
add_opt(common_arg(
14681476
{"--no-context-shift"},
14691477
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,6 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
11571157
cparams.no_perf = params.no_perf;
11581158
cparams.op_offload = !params.no_op_offload;
11591159
cparams.swa_full = params.swa_full;
1160+
cparams.graph_reuse = params.graph_reuse;
11601161

11611162
cparams.type_k = params.cache_type_k;
11621163
cparams.type_v = params.cache_type_v;

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ struct common_params {
330330
bool no_perf = false; // disable performance metrics
331331
bool ctx_shift = true; // context shift on inifinite text generation
332332
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
333+
bool graph_reuse = false; // reuse previous compute graphs when possible
333334

334335
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
335336
bool use_mmap = true; // use mmap for faster loads

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ extern "C" {
374374
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
375375
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
376376
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
377+
378+
bool graph_reuse; // reuse previous compute graphs when possible
377379
};
378380

379381
// model quantization parameters
@@ -1429,6 +1431,7 @@ extern "C" {
14291431

14301432
int32_t n_p_eval;
14311433
int32_t n_eval;
1434+
int32_t n_reused;
14321435
};
14331436

14341437
struct llama_perf_sampler_data {

src/llama-batch.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,31 @@ struct llama_ubatch {
3434
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
3535
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
3636
int8_t * output; // [n_tokens] | i | -
37+
38+
bool is_same(const llama_ubatch & other) const {
39+
bool res =
40+
equal_seqs == other.equal_seqs &&
41+
n_tokens == other.n_tokens &&
42+
n_seq_tokens == other.n_seq_tokens &&
43+
n_seqs == other.n_seqs &&
44+
n_seqs_unq == other.n_seqs_unq &&
45+
(
46+
(!token && !other.token) ||
47+
(!embd && !other.embd)
48+
);
49+
50+
if (!res) {
51+
return false;
52+
}
53+
54+
// TODO: this won't work because seq_id_unq ptr can point to an old balloc that has
55+
// been freed by this point. find a way to fix this
56+
//for (uint32_t s = 0; s < n_seqs_unq; ++s) {
57+
// res &= seq_id_unq[s] == other.seq_id_unq[s];
58+
//}
59+
60+
return res;
61+
}
3762
};
3863

3964
// a helper for sanitizing, fulfilling and splitting a batch

src/llama-context.cpp

Lines changed: 84 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ llama_context::llama_context(
101101

102102
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
103103

104-
cparams.op_offload = params.op_offload;
104+
cparams.op_offload = params.op_offload;
105+
cparams.graph_reuse = params.graph_reuse;
105106

106107
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
107108

@@ -227,8 +228,8 @@ llama_context::llama_context(
227228

228229
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
229230

230-
// buffer used to store the computation graph and the tensor meta data
231-
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
231+
gf_res_prev.reset(new llm_graph_result(max_nodes));
232+
gf_res_reserve.reset(new llm_graph_result(max_nodes));
232233

233234
// TODO: move these checks to ggml_backend_sched
234235
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -388,10 +389,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
388389
return sched.get();
389390
}
390391

391-
ggml_context * llama_context::get_ctx_compute() const {
392-
return ctx_compute.get();
393-
}
394-
395392
uint32_t llama_context::n_ctx() const {
396393
return cparams.n_ctx;
397394
}
@@ -678,38 +675,50 @@ bool llama_context::apply_adapter_cvec(
678675
return cvec.apply(model, data, len, n_embd, il_start, il_end);
679676
}
680677

681-
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
678+
llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
682679
if (mctx && !mctx->apply()) {
683680
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
684681
ret = GGML_STATUS_FAILED;
685682
return nullptr;
686683
}
687684

688-
auto * gf = graph_init();
689-
if (!gf) {
690-
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
691-
ret = GGML_STATUS_FAILED;
692-
return nullptr;
693-
}
685+
auto * res = gf_res_prev.get();
686+
auto * gf = res->get_gf();
694687

695-
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
696-
if (!res) {
697-
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
698-
ret = GGML_STATUS_FAILED;
699-
return nullptr;
700-
}
688+
const auto gparams = graph_params(res, ubatch, mctx, gtype);
701689

702-
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
690+
const bool can_reuse = cparams.graph_reuse && res->update(gparams);
691+
if (can_reuse) {
692+
LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
693+
n_reused++;
694+
} else {
695+
res->reset();
703696

704-
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
705-
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
706-
ret = GGML_STATUS_ALLOC_FAILED;
707-
return nullptr;
697+
ggml_backend_sched_reset(sched.get());
698+
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
699+
700+
//const auto t_start_us = ggml_time_us();
701+
702+
gf = model.build_graph(gparams);
703+
704+
//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
705+
706+
if (!gf) {
707+
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
708+
ret = GGML_STATUS_FAILED;
709+
return nullptr;
710+
}
711+
712+
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
713+
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
714+
ret = GGML_STATUS_ALLOC_FAILED;
715+
return nullptr;
716+
}
708717
}
709718

710719
res->set_inputs(&ubatch);
711720

712-
const auto status = graph_compute(gf, ubatch.n_tokens > 1);
721+
const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
713722
if (status != GGML_STATUS_SUCCESS) {
714723
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
715724
ret = status;
@@ -767,6 +776,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
767776

768777
n_outputs = n_tokens;
769778

779+
gf_res_prev->reset();
780+
770781
ggml_backend_sched_reset(sched.get());
771782
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
772783

@@ -778,7 +789,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
778789
cparams.causal_attn = false;
779790

780791
ggml_status status;
781-
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
792+
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
782793

783794
cparams.causal_attn = causal_attn_org;
784795

@@ -846,7 +857,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
846857

847858
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
848859
// overlap with device computation.
849-
ggml_backend_sched_reset(sched.get());
860+
//ggml_backend_sched_reset(sched.get());
850861

851862
// TODO: hacky solution
852863
if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -1005,11 +1016,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
10051016
n_outputs = n_outputs_new;
10061017
}
10071018

1008-
ggml_backend_sched_reset(sched.get());
1009-
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1010-
10111019
ggml_status status;
1012-
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
1020+
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
10131021

10141022
if (!res) {
10151023
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1192,7 +1200,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
11921200

11931201
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
11941202
// overlap with device computation.
1195-
ggml_backend_sched_reset(sched.get());
1203+
//ggml_backend_sched_reset(sched.get());
11961204

11971205
return 0;
11981206
}
@@ -1275,20 +1283,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
12751283
// graph
12761284
//
12771285

1278-
int32_t llama_context::graph_max_nodes() const {
1279-
return std::max<int32_t>(65536, 5*model.n_tensors());
1280-
}
1281-
1282-
ggml_cgraph * llama_context::graph_init() {
1283-
ggml_init_params params = {
1284-
/*.mem_size =*/ buf_compute_meta.size(),
1285-
/*.mem_buffer =*/ buf_compute_meta.data(),
1286-
/*.no_alloc =*/ true,
1287-
};
1288-
1289-
ctx_compute.reset(ggml_init(params));
1290-
1291-
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1286+
uint32_t llama_context::graph_max_nodes() const {
1287+
return std::max<uint32_t>(65536u, 5u*model.n_tensors());
12921288
}
12931289

12941290
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1297,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
13011297
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
13021298
}
13031299

1300+
gf_res_prev->reset();
1301+
ggml_backend_sched_reset(sched.get());
1302+
13041303
// store the n_outputs as it is, and restore it afterwards
13051304
// TODO: not sure if needed, might simplify in the future by removing this
13061305
const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1309,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
13101309
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
13111310
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
13121311

1313-
auto * gf = graph_init();
1314-
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
1312+
auto * res = gf_res_reserve.get();
13151313

1316-
this->n_outputs = save_n_outputs;
1314+
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
13171315

1318-
if (!res) {
1319-
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
1320-
return nullptr;
1321-
}
1316+
res->reset();
13221317

1323-
ggml_backend_sched_reset(sched.get());
1318+
auto * gf = model.build_graph(gparams);
1319+
1320+
this->n_outputs = save_n_outputs;
13241321

13251322
// initialize scheduler with the specified graph
13261323
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1328,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
13311328
return gf;
13321329
}
13331330

1334-
llm_graph_result_ptr llama_context::graph_build(
1335-
ggml_context * ctx,
1336-
ggml_cgraph * gf,
1337-
const llama_ubatch & ubatch,
1338-
llm_graph_type gtype,
1339-
const llama_memory_context_i * mctx) {
1340-
return model.build_graph(
1341-
{
1342-
/*.ctx =*/ ctx,
1343-
/*.arch =*/ model.arch,
1344-
/*.hparams =*/ model.hparams,
1345-
/*.cparams =*/ cparams,
1346-
/*.ubatch =*/ ubatch,
1347-
/*.sched =*/ sched.get(),
1348-
/*.backend_cpu =*/ backend_cpu,
1349-
/*.cvec =*/ &cvec,
1350-
/*.loras =*/ &loras,
1351-
/*.mctx =*/ mctx,
1352-
/*.cross =*/ &cross,
1353-
/*.n_outputs =*/ n_outputs,
1354-
/*.cb =*/ graph_get_cb(),
1355-
}, gf, gtype);
1331+
llm_graph_params llama_context::graph_params(
1332+
llm_graph_result_i * res,
1333+
const llama_ubatch & ubatch,
1334+
const llama_memory_context_i * mctx,
1335+
llm_graph_type gtype) const {
1336+
return {
1337+
/*.arch =*/ model.arch,
1338+
/*.hparams =*/ model.hparams,
1339+
/*.cparams =*/ cparams,
1340+
/*.ubatch =*/ ubatch,
1341+
/*.gtype =*/ gtype,
1342+
/*.sched =*/ sched.get(),
1343+
/*.backend_cpu =*/ backend_cpu,
1344+
/*.cvec =*/ &cvec,
1345+
/*.loras =*/ &loras,
1346+
/*.mctx =*/ mctx,
1347+
/*.cross =*/ &cross,
1348+
/*.n_outputs =*/ n_outputs,
1349+
/*.cb =*/ graph_get_cb(),
1350+
/*.res =*/ res,
1351+
};
13561352
}
13571353

13581354
ggml_status llama_context::graph_compute(
@@ -1930,6 +1926,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
19301926
data.t_eval_ms = 1e-3 * t_eval_us;
19311927
data.n_p_eval = std::max(1, n_p_eval);
19321928
data.n_eval = std::max(1, n_eval);
1929+
data.n_reused = std::max(0, n_reused);
19331930

19341931
return data;
19351932
}
@@ -1938,6 +1935,7 @@ void llama_context::perf_reset() {
19381935
t_start_us = ggml_time_us();
19391936
t_eval_us = n_eval = 0;
19401937
t_p_eval_us = n_p_eval = 0;
1938+
n_reused = 0;
19411939
}
19421940

19431941
//
@@ -2064,8 +2062,13 @@ void llama_context::opt_epoch_iter(
20642062
break;
20652063
}
20662064

2067-
auto * gf = graph_init();
2068-
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
2065+
auto * res = gf_res_prev.get();
2066+
2067+
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
2068+
2069+
res->reset();
2070+
2071+
auto * gf = model.build_graph(gparams);
20692072

20702073
struct ggml_context * ctx_compute_opt;
20712074
{
@@ -2187,6 +2190,7 @@ llama_context_params llama_context_default_params() {
21872190
/*.no_perf =*/ true,
21882191
/*.op_offload =*/ true,
21892192
/*.swa_full =*/ true,
2193+
/*.graph_reuse =*/ false,
21902194
};
21912195

21922196
return result;
@@ -2807,6 +2811,7 @@ void llama_perf_context_print(const llama_context * ctx) {
28072811
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
28082812
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
28092813
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
2814+
LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
28102815
}
28112816

28122817
void llama_perf_context_reset(llama_context * ctx) {

0 commit comments

Comments
 (0)