@@ -193,6 +193,47 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }

+void llama_context::build_cb(
+        ggml_tensor * cur,
+        const char * name,
+        int il) {
+    if (il >= 0) {
+        ggml_format_name(cur, "%s-%d", name, il);
+    } else {
+        ggml_set_name(cur, name);
+    }
+
+    if (!cparams.offload_kqv) {
+        if (strcmp(name, "kqv_merged_cont") == 0) {
+            // all nodes between the KV store and the attention output are run on the CPU
+            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
+        }
+    }
+
+    // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
+    // FIXME: fix in ggml_backend_sched
+    const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
+    // TODO: during #11213, the requirement for ubatch.n_tokens < 32 was removed to simplify
+    //       not sure if this is still needed, but it can be brought back if needed
+    //if (ubatch.n_tokens < 32 || full_offload) {
+    if (full_offload) {
+        if (il != -1 && strcmp(name, "norm") == 0) {
+            const auto & dev_layer = model.dev_layer(il);
+            for (auto & backend : backends) {
+                if (ggml_backend_get_device(backend.get()) == dev_layer) {
+                    if (ggml_backend_supports_op(backend.get(), cur)) {
+                        ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
+                    }
+                }
+            }
+        }
+    }
+}
+
+ggml_cgraph * llama_context::build_graph(const llama_ubatch & ubatch, bool worst_case) {
+    return model.build_graph(*this, cparams, ubatch, init(), worst_case);
+}
+
 llama_perf_context_data llama_context::perf_get_data() const {
     llama_perf_context_data data = {};

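For context, the new build_cb is meant to be invoked by the graph-building code on every tensor it creates: it assigns a per-layer name and, for the two cases handled above, pins the tensor to a specific backend in the scheduler. A minimal caller sketch, assuming the usual cb(cur, name, il) pattern of the llm_build_* helpers (the helper name and the rms_norm epsilon below are illustrative, not taken from this diff):

ggml_tensor * build_attn_norm_sketch(llama_context & lctx, ggml_context * ctx,
                                     ggml_tensor * inp, ggml_tensor * norm_w, int il) {
    ggml_tensor * cur = ggml_rms_norm(ctx, inp, 1e-5f);   // illustrative eps
    cur = ggml_mul(ctx, cur, norm_w);
    // naming the node "norm-<il>" lets build_cb pin it to the device of layer il
    // when the model is fully offloaded (see the strcmp(name, "norm") branch above)
    lctx.build_cb(cur, "norm", il);
    return cur;
}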
@@ -298,11 +339,7 @@ void llama_context::perf_reset() {

 llama_context_unified::llama_context_unified(
         const llama_model & model,
-        const llama_context_params & params,
-        build_graph_callback && cb_build_graph) :
-    llama_context(model),
-    cb_build_graph(std::move(cb_build_graph)) {
-
+        const llama_context_params & params) : llama_context(model) {
     const auto & hparams = model.hparams;

     cparams.n_seq_max = std::max(1u, params.n_seq_max);
@@ -555,7 +592,7 @@ llama_context_unified::llama_context_unified(
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph

         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
-        ggml_cgraph * gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
+        ggml_cgraph * gf_pp = build_graph(ubatch_pp, true);

         // reserve pp graph first so that buffers are only allocated once
         ggml_backend_sched_reserve(sched.get(), gf_pp);
@@ -564,13 +601,13 @@ llama_context_unified::llama_context_unified(

         // reserve with tg graph to get the number of splits and nodes
         llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
-        ggml_cgraph * gf_tg = this->cb_build_graph(*this, ubatch_tg, true);
+        ggml_cgraph * gf_tg = build_graph(ubatch_tg, true);
         ggml_backend_sched_reserve(sched.get(), gf_tg);
         int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
         int n_nodes_tg = ggml_graph_n_nodes(gf_tg);

         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-        gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
+        gf_pp = build_graph(ubatch_pp, true);
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
             throw std::runtime_error("failed to allocate compute buffers");
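The reservation sequence above is the part worth internalizing: the scheduler is measured against the worst-case prompt-processing (pp) graph first so the compute buffers are sized once, the token-generation (tg) graph is then measured for its split and node counts, and finally the pp graph is reserved again so ggml-alloc never has to grow buffers during inference. A condensed sketch of that pattern, using only the ggml_backend_sched API visible in this diff (the standalone helper and its signature are illustrative):

static bool reserve_worst_case_sketch(llama_context & lctx, ggml_backend_sched_t sched,
                                      const llama_ubatch & ubatch_pp, const llama_ubatch & ubatch_tg) {
    // 1. size the buffers from the largest (pp) graph
    ggml_cgraph * gf_pp = lctx.build_graph(ubatch_pp, /*worst_case=*/true);
    if (!ggml_backend_sched_reserve(sched, gf_pp)) {
        return false;
    }

    // 2. measure the tg graph (splits/nodes); buffers are already large enough
    ggml_cgraph * gf_tg = lctx.build_graph(ubatch_tg, /*worst_case=*/true);
    if (!ggml_backend_sched_reserve(sched, gf_tg)) {
        return false;
    }

    // 3. finish on the pp graph so later allocations reuse the reserved buffers
    gf_pp = lctx.build_graph(ubatch_pp, /*worst_case=*/true);
    return ggml_backend_sched_reserve(sched, gf_pp);
}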
@@ -893,7 +930,7 @@ struct llama_context_unified::batch_manager {
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };

-        ggml_cgraph * gf = lctx.cb_build_graph(lctx, ubatch, true);
+        ggml_cgraph * gf = lctx.build_graph(ubatch, true);

         // initialize scheduler with the worst-case graph
         ggml_backend_sched_reset(lctx.sched.get());
@@ -1004,7 +1041,7 @@ int llama_context_unified::decode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

-    ggml_cgraph * gf = cb_build_graph(*this, ubatch, false);
+    ggml_cgraph * gf = build_graph(ubatch, false);

     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
@@ -1227,7 +1264,7 @@ int llama_context_unified::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

-    ggml_cgraph * gf = cb_build_graph(*this, ubatch, false);
+    ggml_cgraph * gf = build_graph(ubatch, false);

     ggml_backend_sched_alloc_graph(sched.get(), gf);

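Both decode() and encode() now follow the same shape around the graph: reset the scheduler, install the eval callback, build the graph for the current ubatch directly through the member function, then allocate it from the reserved buffers. A simplified sketch of that flow (input setup, output extraction and error handling omitted; the final compute call is assumed to be the usual ggml_backend_sched_graph_compute_async and is not part of this diff):

static void process_ubatch_sketch(llama_context_unified & lctx, ggml_backend_sched_t sched,
                                  const llama_cparams & cparams, const llama_ubatch & ubatch) {
    ggml_backend_sched_reset(sched);   // drop tensor/backend assignments from the previous graph
    ggml_backend_sched_set_eval_callback(sched, cparams.cb_eval, cparams.cb_eval_user_data);

    ggml_cgraph * gf = lctx.build_graph(ubatch, /*worst_case=*/false);

    ggml_backend_sched_alloc_graph(sched, gf);   // allocate from the buffers reserved at construction
    // ... set graph inputs for this ubatch, then:
    // ggml_backend_sched_graph_compute_async(sched, gf);
}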