
Commit c235903

graph : add llama_graph_result
ggml-ci
1 parent f0d3ff2 commit c235903

5 files changed: +167 additions, -350 deletions

src/llama-context.cpp

Lines changed: 44 additions & 23 deletions

@@ -246,31 +246,48 @@ void llama_context::init() {
     uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
     llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
+    int n_splits_pp = -1;
+    int n_nodes_pp = -1;
+
+    int n_splits_tg = -1;
+    int n_nodes_tg = -1;
+
     // reserve pp graph first so that buffers are only allocated once
-    llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-    ggml_cgraph * gf_pp = build_graph(ubatch_pp, true);
-    if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
-        throw std::runtime_error("failed to allocate compute buffers");
+    {
+        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        auto res_pp = graph_build(ubatch_pp, true);
+        auto & gf_pp = res_pp.gf;
+        if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
+            throw std::runtime_error("failed to allocate compute buffers");
+        }
+
+        n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
+        n_nodes_pp = ggml_graph_n_nodes(gf_pp);
     }
-    int n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
-    int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
 
     // reserve with tg graph to get the number of splits and nodes
-    llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-    ggml_cgraph * gf_tg = build_graph(ubatch_tg, true);
-    if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
-        throw std::runtime_error("failed to allocate compute buffers");
+    {
+        llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        auto res_tg = graph_build(ubatch_tg, true);
+        auto & gf_tg = res_tg.gf;
+        if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
+            throw std::runtime_error("failed to allocate compute buffers");
+        }
+        n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
+        n_nodes_tg = ggml_graph_n_nodes(gf_tg);
     }
-    int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
-    int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
 
     // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-    gf_pp = build_graph(ubatch_pp, true);
-    if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
-        throw std::runtime_error("failed to allocate compute buffers");
+    {
+        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        auto res_pp = graph_build(ubatch_pp, true);
+        auto & gf_pp = res_pp.gf;
+        if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
+            throw std::runtime_error("failed to allocate compute buffers");
+        }
     }
 
     for (size_t i = 0; i < backend_ptrs.size(); ++i) {

@@ -890,7 +907,7 @@ void llama_context::build_cb(
     }
 }
 
-ggml_cgraph * llama_context::build_graph(const llama_ubatch & ubatch, bool worst_case) {
+llama_graph_result llama_context::graph_build(const llama_ubatch & ubatch, bool worst_case) {
     return model.build_graph(*this, cparams, ubatch, graph_init(), worst_case);
 }
 
@@ -1814,11 +1831,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
-        ggml_cgraph * gf = build_graph(ubatch, true);
+        auto res = graph_build(ubatch, true);
 
         // initialize scheduler with the worst-case graph
         ggml_backend_sched_reset(sched.get());
-        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+        if (!ggml_backend_sched_reserve(sched.get(), res.gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         }
 
@@ -1828,7 +1845,9 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = build_graph(ubatch, false);
+        auto res = graph_build(ubatch, false);
+
+        auto & gf = res.gf;
 
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
@@ -2073,7 +2092,9 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf = build_graph(ubatch, false);
+    auto res = graph_build(ubatch, false);
+
+    auto & gf = res.gf;
 
     ggml_backend_sched_alloc_graph(sched.get(), gf);
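For reference, a minimal sketch of the call-site pattern the diff switches decode() and encode() to: the caller now receives a llama_graph_result and pulls the cgraph out of it before handing it to the scheduler. The helper below is illustrative only, not part of the commit, and assumes it is compiled inside src/ with ggml-backend.h available.

```cpp
#include "llama-graph.h"

#include "ggml-backend.h"

// illustrative helper (not from this commit): allocate and run a graph that
// arrives wrapped in the new result struct
static ggml_status compute_graph_result(ggml_backend_sched_t sched, const llama_graph_result & res) {
    // the scheduler still only needs the cgraph; the extra tensor handles
    // (t_logits, t_embd) are there for callers that want to read outputs later
    ggml_cgraph * gf = res.gf;

    if (!ggml_backend_sched_alloc_graph(sched, gf)) {
        return GGML_STATUS_ALLOC_FAILED;
    }

    return ggml_backend_sched_graph_compute(sched, gf);
}
```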

src/llama-context.h

Lines changed: 3 additions & 3 deletions

@@ -95,6 +95,9 @@ struct llama_context : public llama_graph_i {
     // zero-out inputs and create ggml_context
     virtual ggml_context_ptr graph_init();
 
+    // TODO: add encode/decode graphs
+    virtual llama_graph_result graph_build(const llama_ubatch & ubatch, bool worst_case);
+
     // returns the result of ggml_backend_sched_graph_compute_async execution
     virtual enum ggml_status graph_compute(
             ggml_cgraph * graph,
@@ -145,9 +148,6 @@ struct llama_context : public llama_graph_i {
             const llama_ubatch & ubatch,
             int il);
 
-    // TODO: add encode/decode graphs
-    virtual ggml_cgraph * build_graph(const llama_ubatch & ubatch, bool worst_case);
-
     // apply control vector for layer il
     virtual ggml_tensor * build_cvec(
             ggml_context * ctx0,
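A minimal sketch of the signature change from a caller's point of view, assuming compilation inside src/ and that the new virtual is reachable from the call site (the diff keeps it next to graph_init() and graph_compute()); the helper below is hypothetical:

```cpp
#include "llama-context.h"

// hypothetical helper: the graph now arrives wrapped in llama_graph_result
// instead of as a bare ggml_cgraph *
static ggml_cgraph * build_worst_case_graph(llama_context & ctx, const llama_ubatch & ubatch) {
    llama_graph_result res = ctx.graph_build(ubatch, /*worst_case =*/ true);

    return res.gf; // res.t_logits / res.t_embd become useful once builders fill them
}
```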

src/llama-graph.h

Lines changed: 7 additions & 0 deletions

@@ -10,6 +10,13 @@ struct ggml_context;
 struct ggml_tensor;
 struct llama_ubatch;
 
+struct llama_graph_result {
+    ggml_cgraph * gf = nullptr;
+
+    ggml_tensor * t_logits = nullptr;
+    ggml_tensor * t_embd = nullptr;
+};
+
 // TODO: can become more granular in the future
 class llama_graph_i {
 public:
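Only gf is populated by graph_build() as of this commit; t_logits and t_embd are placeholders for later changes. The sketch below shows one way a builder could fill them, assuming the output tensors keep the names llama.cpp's build callback assigns ("result_output", "result_norm"); that is an assumption about follow-up work, not something this diff does:

```cpp
#include "llama-graph.h"

#include "ggml.h"

// speculative sketch (not from this commit): wrap a finished cgraph in the new
// result struct and look up the output tensors by the names given during the build
static llama_graph_result make_graph_result(ggml_cgraph * gf) {
    llama_graph_result res;

    res.gf = gf;

    // assumption: the build callback tagged these tensors with these names
    res.t_logits = ggml_graph_get_tensor(gf, "result_output");
    res.t_embd   = ggml_graph_get_tensor(gf, "result_norm");

    return res;
}
```

Returning a struct instead of a bare ggml_cgraph * lets later commits attach the logits and embeddings tensors without touching every call site again.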
