 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "../ggml/src/ggml-impl.h"
 
 #if defined(GGML_USE_VULKAN)
 #  include "ggml-vulkan.h"
@@ -3254,6 +3255,17 @@ struct llama_sbatch {
     }
 };
 
+// Object used to allow caching of the GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    bool is_active = false;      // whether the cached graph may be reused for the next token
+    ggml_cgraph * gf;            // graph cached from the previous token
+    size_t n;                    // kv_self.n at the time the graph was built
+    ggml_backend_t backend_res;  // backend holding the logits tensor
+    ggml_backend_t backend_embd; // backend holding the embeddings tensor
+    struct ggml_tensor * res;    // output logits tensor
+    struct ggml_tensor * embd;   // output embeddings tensor
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -3352,6 +3364,8 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    struct ggml_cached_graph cached_graph;
 };
 
 struct llama_lora_weight {
@@ -9146,7 +9160,6 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
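The "v_cache_view" name recorded via cb() here (and the matching "k_cache_view" set earlier in the same function) gets the layer index appended by the callback, e.g. "k_cache_view-7", and this is exactly what the cached-graph fix-up loop later in this diff keys on. A small standalone sketch of that prefix-match-plus-atoi parsing, using made-up node names:

#include <cstdio>
#include <cstring>
#include <cstdlib>

int main() {
    // made-up node names following the "<name>-<layer>" convention used by cb()
    const char * names[] = { "k_cache_view-0", "v_cache_view-12", "some_other_node-3" };

    const char * k_prefix = "k_cache_view-";
    const char * v_prefix = "v_cache_view-";

    for (const char * name : names) {
        if (strncmp(name, k_prefix, strlen(k_prefix)) == 0) {
            printf("K cache view, layer %d\n", atoi(name + strlen(k_prefix)));
        } else if (strncmp(name, v_prefix, strlen(v_prefix)) == 0) {
            printf("V cache view, layer %d\n", atoi(name + strlen(v_prefix)));
        } else {
            printf("not a cache view: %s\n", name);
        }
    }
    return 0;
}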
@@ -17181,11 +17194,44 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
+        ggml_cgraph * gf;
+        // the output is always the last tensor in the graph
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build the graph only if graph caching is not possible
+        if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+        gf = llama_build_graph(lctx, ubatch, false);
+
+        // Set whether GGML graph caching is in use within the GGML module, based on
+        // whether caching was activated here during the previous token
+        ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);
+
+        // Disable future graph caching if the GGML_DISABLE_GRAPH_CACHING env var is set,
+        // if there are multiple devices, if the batch size is greater than 1,
+        // or if the number of scheduler splits is not 2.
+        // TODO: enable graph caching for these cases
+        bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+            || (llama_get_device_count(model) > 1)
+            || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                disable_cached_ggml_graph = true; // a batched ADD implies batch size > 1
+                break;
+            }
+        }
+
+        // Set whether graph caching should be used for future tokens
+        lctx.cached_graph.is_active = !disable_cached_ggml_graph;
 
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = ggml_graph_node(gf, -1);
-        struct ggml_tensor * embd = ggml_graph_node(gf, -2);
+        res  = ggml_graph_node(gf, -1);
+        embd = ggml_graph_node(gf, -2);
 
         if (lctx.n_outputs == 0) {
             // no output
@@ -17205,10 +17251,60 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+        lctx.cached_graph.res  = res;
+        lctx.cached_graph.embd = embd;
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+        }
+        else {
+            gf   = lctx.cached_graph.gf;
+            res  = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        // Update the K and V cache parameters in the cached graph.
+        if (gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {
+
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t kv_head = kv_self.head;
+
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_CPY) {
+
+                    // K cache
+                    const char * k_prefix = "k_cache_view-";
+                    if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(k_prefix)); // layer index from the node name
+                        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                        ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                        size_t tmp_offset = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head;
+                        node->src[1]->data = static_cast<char *>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                    // V cache
+                    const char * v_prefix = "v_cache_view-";
+                    if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
+                        int il = atoi(node->src[1]->name + strlen(v_prefix)); // layer index from the node name
+                        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+                        ggml_tensor * tmp_tensor = kv_self.v_l[il];
+                        size_t tmp_offset;
+                        if (cparams.flash_attn) {
+                            tmp_offset = kv_head*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                        } else {
+                            tmp_offset = kv_head*ggml_element_size(kv_self.v_l[il]); // V cache is transposed without flash attention
+                        }
+                        node->src[1]->data = static_cast<char *>(tmp_tensor->data) + tmp_offset;
+                    }
+
+                }
+            }
+
+        }
+
         llama_set_inputs(lctx, ubatch);
 
         llama_graph_compute(lctx, gf, n_threads, threadpool);
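A note on the offset arithmetic in the fix-up loop above: the K cache stores one contiguous row of n_embd_k_gqa elements per token, so advancing the view to slot kv_head means skipping kv_head whole rows, while the non-flash-attention V cache is stored transposed, so the per-view byte offset is only element_size * kv_head. A minimal standalone sketch with made-up sizes (f16 cache, hypothetical n_embd_k_gqa and kv_head values):

#include <cstdio>
#include <cstddef>

int main() {
    // hypothetical values for illustration only
    const size_t n_embd_k_gqa = 1024; // K elements stored per token for one layer
    const size_t elem_size    = 2;    // bytes per element (f16)
    const size_t kv_head      = 37;   // first free slot in the KV cache

    // K cache: skip kv_head full rows (mirrors ggml_row_size(type, n_embd_k_gqa) * kv_head)
    const size_t k_offset = n_embd_k_gqa * elem_size * kv_head;

    // V cache without flash attention is transposed, so consecutive tokens are
    // adjacent within a row and the byte offset is just elem_size * kv_head
    const size_t v_offset = elem_size * kv_head;

    printf("k_offset = %zu bytes, v_offset = %zu bytes\n", k_offset, v_offset);
    return 0;
}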
@@ -17231,11 +17327,15 @@ static int llama_decode_internal(
         // extract logits
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
             float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
             const int32_t n_outputs_new = lctx.n_outputs;
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
             if (n_outputs_new) {
                 GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -17247,6 +17347,12 @@ static int llama_decode_internal(
1724717347 // extract embeddings
1724817348 if (embd) {
1724917349 ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
17350+
17351+
17352+ if(!ggml_use_cached_graph(lctx.sched))
17353+ lctx.cached_graph.backend_embd = backend_embd;
17354+ else
17355+ backend_embd = lctx.cached_graph.backend_embd;
1725017356 GGML_ASSERT(backend_embd != nullptr);
1725117357
1725217358 switch (cparams.pooling_type) {
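One observation on the last two hunks: when the graph is reused, the scheduler has not re-planned it for the current token, so the patch records backend_res/backend_embd when the graph is first built and falls back to those recorded handles afterwards, presumably because the ggml_backend_sched_get_tensor_backend lookup is only reliable right after the scheduler has allocated the graph. A minimal standalone sketch of that query-once-then-reuse pattern, with hypothetical stand-ins for the scheduler query and the cache slot:

#include <cstdio>

// hypothetical stand-ins: backend_t models ggml_backend_t,
// query_backend() models ggml_backend_sched_get_tensor_backend()
struct backend_t { int id; };

static backend_t query_backend() {
    printf("scheduler queried\n");
    return backend_t{1};
}

int main() {
    backend_t cached_backend{};   // models lctx.cached_graph.backend_res
    bool graph_is_cached = false; // models ggml_use_cached_graph(lctx.sched)

    for (int token = 0; token < 3; token++) {
        backend_t backend;
        if (!graph_is_cached) {
            backend         = query_backend(); // fresh graph: ask the scheduler
            cached_backend  = backend;         // remember the handle for reused graphs
            graph_is_cached = true;            // caching kicks in from the next token
        } else {
            backend = cached_backend;          // reused graph: use the recorded handle
        }
        printf("token %d -> backend %d\n", token, backend.id);
    }
    return 0;
}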