@@ -2712,6 +2712,16 @@ struct llama_model {
     }
 };
 
+// Object used to allow caching of the GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    ggml_cgraph * gf;            // cached compute graph
+    size_t n;                    // kv_self.n the graph was built for
+    ggml_backend_t backend_res;  // backend holding the logits tensor
+    ggml_backend_t backend_embd; // backend holding the embeddings tensor
+    struct ggml_tensor * res;    // output (logits) tensor of the cached graph
+    struct ggml_tensor * embd;   // output (embeddings) tensor of the cached graph
+};
+
 struct llama_context {
     llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
     ~llama_context() {
@@ -2813,6 +2823,8 @@ struct llama_context {
 
     // control vectors
     struct llama_control_vector cvec;
+
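+    // cached GGML graph state, allowing graph reuse across tokens where possible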
+    struct ggml_cached_graph cached_graph;
 };
 
 static size_t llama_get_device_count(const llama_model & model) {
@@ -14524,12 +14536,37 @@ static int llama_decode_internal(
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-    ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+    ggml_cgraph * gf;
     // the output is always the last tensor in the graph
-    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+
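+    // The cached graph is only valid for the kv_self.n it was built with, since kv_self.n
+    // determines the sizes of the KV cache views used in the attention; rebuild if it changed.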
+    bool n_has_changed_since_last_token = false;
+    if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+    lctx.cached_graph.n = kv_self.n;
+
+    // Re-build the graph only if graph caching is not possible
+    if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+        gf = llama_build_graph(lctx, u_batch, false);
+
+        // disable future graph caching in the presence of the env var,
+        // if there are multiple devices, or if the batch size is greater than 1
+        // TODO: enable graph caching for these cases
+        bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+            || (llama_get_device_count(model) > 1);
+        for (int i = 0; i < gf->n_nodes; i++) {
+            if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                disable_cached_ggml_graph = true;
+                break;
+            }
+        }
+
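+        // flag the scheduler so that subsequent tokens can reuse this graph instead of rebuilding it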
+        if (!disable_cached_ggml_graph) ggml_set_cached_graph(lctx.sched, true);
 
+        // the output is always the last tensor in the graph
+        res  = gf->nodes[gf->n_nodes - 1];
+        embd = gf->nodes[gf->n_nodes - 2];
     if (lctx.n_outputs == 0) {
         // no output
         res = nullptr;
@@ -14545,10 +14582,71 @@ static int llama_decode_internal(
         embd = nullptr; // do not extract embeddings when not needed
         GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
     }
+        lctx.cached_graph.res = res;
+        lctx.cached_graph.embd = embd;
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+    }
+    else {
+        gf = lctx.cached_graph.gf;
+        res = lctx.cached_graph.res;
+        embd = lctx.cached_graph.embd;
+    }
+    lctx.cached_graph.gf = gf;
+
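+    // If the scheduler is holding a cacheable graph, refresh the KV cache write addresses in its copy
+    // nodes: kv_self.head advances between tokens, so the addresses captured at build time go stale.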
+    if (ggml_use_cached_graph(lctx.sched)) {
+
+        // If using flash attention, find the mask node so it can be skipped when updating
+        // the KV cache parameters in the cached graph nodes below
+        void * flash_attn_mask_node = nullptr;
+        if (cparams.flash_attn) {
+            for (int i = 0; i < gf->n_nodes; i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                    flash_attn_mask_node = node->src[3];
+                    break;
+                }
+            }
+        }
+
+        // Temporarily store the KV cache parameters that will need to be updated in the cached graph.
+        const struct llama_hparams & hparams = model.hparams;
+        const int64_t n_layer = hparams.n_layer;
+        const int64_t kv_head = kv_self.head;
+        std::vector<void *> kv_cache_ptrs;
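+        // These offsets mirror the KV cache view offsets used when the graph was built: with flash
+        // attention the V cache is laid out like the K cache, otherwise it is stored transposed.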
+        for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+            ggml_tensor * tmp_tensor = kv_self.k_l[il];
+            size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+            kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            tmp_tensor = kv_self.v_l[il];
+            if (cparams.flash_attn) {
+                tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+            } else {
+                tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+            }
+            kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+        }
+
+        // Update KV cache parameters in the cached graph.
+        int copy_op_count = 0;
+        if (gf != nullptr && gf->nodes != nullptr) {
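+            // The copy nodes appear in the same per-layer K-then-V order in which kv_cache_ptrs was
+            // filled above; the flash-attention mask cast is also a copy node and must be skipped.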
+            for (int i = 0; i < gf->n_nodes; i++) {
+                ggml_tensor * node = gf->nodes[i];
+                if (node->op == GGML_OP_CPY) {
+                    if (node != flash_attn_mask_node) {
+                        node->src[1]->data = kv_cache_ptrs[copy_op_count];
+                        copy_op_count++;
+                    }
+                }
+            }
+        }
+
+    }
+
     llama_set_inputs(lctx, u_batch);
 
     llama_graph_compute(lctx, gf, n_threads);
@@ -14571,11 +14669,15 @@ static int llama_decode_internal(
     // extract logits
     if (res) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-        GGML_ASSERT(backend_res != nullptr);
-        GGML_ASSERT(lctx.logits != nullptr);
-
         float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
         const int32_t n_outputs_new = lctx.n_outputs;
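+        // The scheduler reset at the start of this call clears per-tensor backend assignments and the
+        // cached-graph path does not re-schedule, so the lookup above may not be valid; record the
+        // backend when the graph is built and fall back to that record when reusing the cached graph.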
+        if (!ggml_use_cached_graph(lctx.sched))
+            lctx.cached_graph.backend_res = backend_res;
+        else
+            backend_res = lctx.cached_graph.backend_res;
+
+        GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(lctx.logits != nullptr);
 
         if (n_outputs_new) {
             GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -14587,6 +14689,12 @@ static int llama_decode_internal(
1458714689 // extract embeddings
1458814690 if (embd) {
1458914691 ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
14692+
14693+
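+        // As with the logits above, record the embeddings backend at build time and reuse it when
+        // running from the cached graph.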
+        if (!ggml_use_cached_graph(lctx.sched))
+            lctx.cached_graph.backend_embd = backend_embd;
+        else
+            backend_embd = lctx.cached_graph.backend_embd;
         GGML_ASSERT(backend_embd != nullptr);
 
         switch (cparams.pooling_type) {