
Commit f63aeec

llama : models now build their graphs using llama_graph_i
ggml-ci
1 parent 0ab50f1 commit f63aeec

File tree

6 files changed: +7457 -7441 lines


src/llama-context.cpp

Lines changed: 48 additions & 11 deletions
@@ -193,6 +193,47 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
+void llama_context::build_cb(
+         ggml_tensor * cur,
+          const char * name,
+                   int il) {
+    if (il >= 0) {
+        ggml_format_name(cur, "%s-%d", name, il);
+    } else {
+        ggml_set_name(cur, name);
+    }
+
+    if (!cparams.offload_kqv) {
+        if (strcmp(name, "kqv_merged_cont") == 0) {
+            // all nodes between the KV store and the attention output are run on the CPU
+            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
+        }
+    }
+
+    // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
+    // FIXME: fix in ggml_backend_sched
+    const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
+    // TODO: during #11213, the requirement for ubatch.n_tokens < 32 was removed to simplify
+    //       not sure if this is still needed, but it can be brought back if needed
+    //if (ubatch.n_tokens < 32 || full_offload) {
+    if (full_offload) {
+        if (il != -1 && strcmp(name, "norm") == 0) {
+            const auto & dev_layer = model.dev_layer(il);
+            for (auto & backend : backends) {
+                if (ggml_backend_get_device(backend.get()) == dev_layer) {
+                    if (ggml_backend_supports_op(backend.get(), cur)) {
+                        ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
+                    }
+                }
+            }
+        }
+    }
+}
+
+ggml_cgraph * llama_context::build_graph(const llama_ubatch & ubatch, bool worst_case) {
+    return model.build_graph(*this, cparams, ubatch, init(), worst_case);
+}
+
 llama_perf_context_data llama_context::perf_get_data() const {
     llama_perf_context_data data = {};
 
@@ -298,11 +339,7 @@ void llama_context::perf_reset() {
 
 llama_context_unified::llama_context_unified(
         const llama_model & model,
-        const llama_context_params & params,
-        build_graph_callback && cb_build_graph) :
-    llama_context(model),
-    cb_build_graph(std::move(cb_build_graph)) {
-
+        const llama_context_params & params) : llama_context(model) {
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
@@ -555,7 +592,7 @@ llama_context_unified::llama_context_unified(
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
+        ggml_cgraph * gf_pp = build_graph(ubatch_pp, true);
 
         // reserve pp graph first so that buffers are only allocated once
         ggml_backend_sched_reserve(sched.get(), gf_pp);
@@ -564,13 +601,13 @@
 
         // reserve with tg graph to get the number of splits and nodes
        llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf_tg = this->cb_build_graph(*this, ubatch_tg, true);
+        ggml_cgraph * gf_tg = build_graph(ubatch_tg, true);
         ggml_backend_sched_reserve(sched.get(), gf_tg);
         int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
         int n_nodes_tg  = ggml_graph_n_nodes(gf_tg);
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-        gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
+        gf_pp = build_graph(ubatch_pp, true);
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
             throw std::runtime_error("failed to allocate compute buffers");
@@ -893,7 +930,7 @@ struct llama_context_unified::batch_manager {
        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
-        ggml_cgraph * gf = lctx.cb_build_graph(lctx, ubatch, true);
+        ggml_cgraph * gf = lctx.build_graph(ubatch, true);
 
        // initialize scheduler with the worst-case graph
        ggml_backend_sched_reset(lctx.sched.get());
@@ -1004,7 +1041,7 @@ int llama_context_unified::decode(llama_batch & inp_batch) {
    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf = cb_build_graph(*this, ubatch, false);
+    ggml_cgraph * gf = build_graph(ubatch, false);
 
    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
@@ -1227,7 +1264,7 @@ int llama_context_unified::encode(llama_batch & inp_batch) {
    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf = cb_build_graph(*this, ubatch, false);
+    ggml_cgraph * gf = build_graph(ubatch, false);
 
    ggml_backend_sched_alloc_graph(sched.get(), gf);
 
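The delegation above (llama_context::build_graph forwarding to model.build_graph(*this, ...)) means the model-side builder now sees the context only through the llama_graph_i interface. The following is a rough, hypothetical sketch of what a model-side builder could look like under that contract; the function name, weights, and layer layout are illustrative and not the actual llama-model.cpp code:

// Hypothetical sketch: a model-side graph builder that reports each tensor it creates
// back through llama_graph_i::build_cb(), so the context can apply its naming and
// offload policy (see llama_context::build_cb above). Weights and shapes are illustrative.
#include "ggml.h"
#include "llama-graph.h"

static ggml_cgraph * build_layer_sketch(
        llama_graph_i & lgf,     // the llama_context, seen only as the graph interface
        ggml_context  * ctx0,    // ggml context backing the graph (assumed provided by the caller)
        ggml_tensor   * inp,     // layer input activations (assumed)
        ggml_tensor   * w_norm,  // per-layer norm weight (assumed)
        ggml_tensor   * w_out,   // per-layer projection weight (assumed)
        int             il) {
    ggml_cgraph * gf = ggml_new_graph(ctx0);

    // RMS norm + weight; build_cb() names the result "norm-<il>" and, under full offload,
    // may pin it to the backend that owns layer il
    ggml_tensor * cur = ggml_rms_norm(ctx0, inp, 1e-5f);
    cur = ggml_mul(ctx0, cur, w_norm);
    lgf.build_cb(cur, "norm", il);

    // projection; reported back as well so the same per-tensor policy applies
    cur = ggml_mul_mat(ctx0, w_out, cur);
    lgf.build_cb(cur, "ffn_out", il);

    ggml_build_forward_expand(gf, cur);
    return gf;
}

Compared to the removed build_graph_callback std::function, the virtual interface keeps the per-tensor policy (naming, CPU fallback for kqv_merged_cont, norm reassignment) inside llama_context while letting the model own the graph topology.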

src/llama-context.h

Lines changed: 14 additions & 12 deletions
@@ -82,6 +82,14 @@ struct llama_context : public llama_graph_i {
             int32_t il_start,
             int32_t il_end);
 
+    virtual void build_cb(
+             ggml_tensor * cur,
+              const char * name,
+                       int il);
+
+    // TODO: add encode/decode graphs
+    virtual ggml_cgraph * build_graph(const llama_ubatch & ubatch, bool worst_case);
+
     // decode a batch of tokens by evaluating the transformer
     // in case of unsuccessful decoding (error or warning),
     // the kv_cache state will be returned to its original state
@@ -171,11 +179,6 @@ struct llama_context : public llama_graph_i {
 
     // members
 
-    // TODO: temporary public until llama_context implements the graph build function
-    std::vector<ggml_backend_ptr> backends;
-    ggml_backend_t backend_cpu = nullptr;
-    ggml_backend_sched_ptr sched;
-
 protected:
     const llama_model & model;
 
@@ -189,8 +192,13 @@ struct llama_context : public llama_graph_i {
     ggml_abort_callback abort_callback      = nullptr;
     void *              abort_callback_data = nullptr;
 
+    ggml_backend_t backend_cpu = nullptr;
+    std::vector<ggml_backend_ptr> backends;
+
     std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
 
+    ggml_backend_sched_ptr sched;
+
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
 
@@ -213,13 +221,9 @@ class llama_context_unified : public llama_context {
 public:
     struct batch_manager;
 
-    // TODO: tmp until llama_model starts implementing the graph build function
-    typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
-
     llama_context_unified(
             const llama_model & model,
-            const llama_context_params & params,
-            build_graph_callback && cb_build_graph);
+            const llama_context_params & params);
 
     virtual ~llama_context_unified();
 
@@ -244,8 +248,6 @@ class llama_context_unified : public llama_context {
 
     llama_sbatch sbatch;
 
-    build_graph_callback cb_build_graph;
-
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
 

src/llama-graph.h

Lines changed: 7 additions & 1 deletion
@@ -7,9 +7,15 @@ struct ggml_context;
 struct ggml_tensor;
 struct llama_ubatch;
 
-// TODO: pass to llama_model graph build
+// TODO: can become more granular in the future
 class llama_graph_i {
 public:
+    // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
+    virtual void build_cb(
+             ggml_tensor * cur,
+              const char * name,
+                       int il) = 0;
+
     // apply control vector for layer il
     virtual ggml_tensor * build_cvec(
             ggml_context * ctx0,
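For reference, a minimal, hypothetical implementer of the new pure-virtual build_cb might apply nothing more than the naming convention; this is not part of the commit, and the remaining pure-virtual methods of llama_graph_i (e.g. build_cvec) would still need their own overrides:

// Hypothetical sketch: the smallest reasonable build_cb() override, applying only the
// "<name>-<il>" naming scheme and no backend scheduling. Not part of this commit.
#include "ggml.h"
#include "llama-graph.h"

struct llama_graph_trace : public llama_graph_i {
    void build_cb(
            ggml_tensor * cur,
            const char * name,
            int il) override {
        // same naming rule as llama_context::build_cb: layer tensors get a "-<il>" suffix
        if (il >= 0) {
            ggml_format_name(cur, "%s-%d", name, il);
        } else {
            ggml_set_name(cur, name);
        }
    }

    // build_cvec() and any other pure-virtual members of llama_graph_i would need
    // overrides too before this type could be instantiated; omitted in this sketch.
};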
