Commit 172f616

cont : return important tensors
ggml-ci
1 parent c235903 · commit 172f616

5 files changed: +293, -46 lines

src/llama-context.cpp

Lines changed: 19 additions & 10 deletions
@@ -255,7 +255,8 @@ void llama_context::init() {
     // reserve pp graph first so that buffers are only allocated once
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        auto res_pp = graph_build(ubatch_pp, true);
+        auto ctx = graph_init();
+        auto res_pp = graph_build(ctx, ubatch_pp, true);
         auto & gf_pp = res_pp.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -269,7 +270,8 @@ void llama_context::init() {
     // reserve with tg graph to get the number of splits and nodes
     {
         llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        auto res_tg = graph_build(ubatch_tg, true);
+        auto ctx = graph_init();
+        auto res_tg = graph_build(ctx, ubatch_tg, true);
         auto & gf_tg = res_tg.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
@@ -282,7 +284,8 @@ void llama_context::init() {
     // reserve again with pp graph to avoid ggml-alloc reallocations during inference
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        auto res_pp = graph_build(ubatch_pp, true);
+        auto ctx = graph_init();
+        auto res_pp = graph_build(ctx, ubatch_pp, true);
         auto & gf_pp = res_pp.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -569,6 +572,13 @@ ggml_context_ptr llama_context::graph_init() {
     return ggml_context_ptr { ggml_init(params) };
 }
 
+llama_graph_result llama_context::graph_build(
+        ggml_context_ptr & ctx,
+        const llama_ubatch & ubatch,
+        bool worst_case) {
+    return model.build_graph(ctx, *this, cparams, ubatch, worst_case);
+}
+
 enum ggml_status llama_context::graph_compute(
         ggml_cgraph * graph,
         bool batched) {
@@ -907,10 +917,6 @@ void llama_context::build_cb(
     }
 }
 
-llama_graph_result llama_context::graph_build(const llama_ubatch & ubatch, bool worst_case) {
-    return model.build_graph(*this, cparams, ubatch, graph_init(), worst_case);
-}
-
 llama_perf_context_data llama_context::perf_get_data() const {
     llama_perf_context_data data = {};
 
@@ -1831,7 +1837,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
-        auto res = graph_build(ubatch, true);
+        auto ctx = graph_init();
+        auto res = graph_build(ctx, ubatch, true);
 
         // initialize scheduler with the worst-case graph
         ggml_backend_sched_reset(sched.get());
@@ -1845,7 +1852,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    auto res = graph_build(ubatch, false);
+    auto ctx = graph_init();
+    auto res = graph_build(ctx, ubatch, false);
 
     auto & gf = res.gf;
 
@@ -2092,7 +2100,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    auto res = graph_build(ubatch, false);
+    auto ctx = graph_init();
+    auto res = graph_build(ctx, ubatch, false);
 
    auto & gf = res.gf;

src/llama-context.h

Lines changed: 4 additions & 1 deletion
@@ -96,7 +96,10 @@ struct llama_context : public llama_graph_i {
     virtual ggml_context_ptr graph_init();
 
     // TODO: add encode/decode graphs
-    virtual llama_graph_result graph_build(const llama_ubatch & ubatch, bool worst_case);
+    virtual llama_graph_result graph_build(
+            ggml_context_ptr & ctx,
+            const llama_ubatch & ubatch,
+            bool worst_case);
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
     virtual enum ggml_status graph_compute(
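
A minimal sketch of the new calling convention, mirroring the call sites updated in src/llama-context.cpp above. The names come from this diff; error handling is omitted and it assumes the members are accessible at the call site, so treat it as an illustration rather than code from this commit:

    // the caller now creates and owns the ggml_context, instead of
    // graph_build() calling graph_init() internally
    ggml_context_ptr ctx = graph_init();
    llama_graph_result res = graph_build(ctx, ubatch, /*worst_case=*/true);

    // the built compute graph is still reached through the result struct
    if (!ggml_backend_sched_reserve(sched.get(), res.gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
    }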

src/llama-graph.h

Lines changed: 4 additions & 2 deletions
@@ -13,8 +13,10 @@ struct llama_ubatch;
 struct llama_graph_result {
     ggml_cgraph * gf = nullptr;
 
-    ggml_tensor * t_logits = nullptr;
-    ggml_tensor * t_embd   = nullptr;
+    // important graph nodes
+    ggml_tensor * t_logits      = nullptr;
+    ggml_tensor * t_embd        = nullptr;
+    ggml_tensor * t_embd_pooled = nullptr;
 };
 
 // TODO: can become more granular in the future
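
The new t_* members let a caller read the model outputs directly from the graph result, rather than having to locate the output nodes inside the compute graph. A hedged sketch of what a consumer could look like, assuming the ggml backend API (ggml_nelements, ggml_nbytes, ggml_backend_tensor_get) and access to the llama_context members shown above; the helper itself is hypothetical and not part of this commit:

    #include <vector>

    // hypothetical helper: build the graph, run it, and copy the logits to host memory
    static std::vector<float> read_logits_sketch(llama_context & lctx, const llama_ubatch & ubatch) {
        auto ctx = lctx.graph_init();
        auto res = lctx.graph_build(ctx, ubatch, /*worst_case=*/false);

        // evaluate the graph (scheduling and error handling omitted)
        lctx.graph_compute(res.gf, /*batched=*/true);

        std::vector<float> logits;
        if (res.t_logits) {
            logits.resize(ggml_nelements(res.t_logits));
            ggml_backend_tensor_get(res.t_logits, logits.data(), 0, ggml_nbytes(res.t_logits));
        }
        return logits;
    }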
