@@ -2714,6 +2714,7 @@ struct llama_model {
 
 // Object used to allow caching of GGML graph between tokens where possible.
 struct ggml_cached_graph {
+    bool is_active = false;
     ggml_cgraph * gf;
     size_t n;
     ggml_backend_t backend_res;
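
For orientation, here is a minimal sketch, not the actual llama.cpp implementation, of how the new `is_active` flag can gate graph reuse between tokens. Only the struct fields visible in the hunk above are mirrored; the `try_reuse_cached_graph` helper is hypothetical and exists purely for illustration.

```cpp
// Minimal sketch, assuming the field layout shown in the hunk above.
// try_reuse_cached_graph is a hypothetical helper, not part of the patch.
#include <cstddef>

struct ggml_cgraph;                              // opaque GGML graph type
struct ggml_backend;
typedef struct ggml_backend * ggml_backend_t;    // matches the GGML typedef

struct ggml_cached_graph {
    bool is_active = false;       // set once a graph has been cached for reuse
    ggml_cgraph *  gf          = nullptr;
    size_t         n           = 0;
    ggml_backend_t backend_res = nullptr;
};

// Return the cached graph when the previous token activated caching,
// otherwise return nullptr so the caller rebuilds via llama_build_graph().
static ggml_cgraph * try_reuse_cached_graph(const ggml_cached_graph & cache) {
    return (cache.is_active && cache.gf != nullptr) ? cache.gf : nullptr;
}
```

Because the flag defaults to false, the very first token always builds a fresh graph; only later tokens can take the reuse path.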
@@ -14550,7 +14551,11 @@ static int llama_decode_internal(
 
         gf = llama_build_graph(lctx, u_batch, false);
 
-        // disable future graph caching in presense of env var,
+        // Set whether GGML graph caching is in use within GGML module, based on
+        // whether caching was activated here during the previous token
+        ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);
+
+        // Disable future graph caching in presence of env var,
         // if there are multiple devices, or if batch size is greater than 1
         // TO DO enable graph caching for these cases
         bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
@@ -14562,7 +14567,8 @@ static int llama_decode_internal(
             }
         }
 
-        if(!disable_cached_ggml_graph) ggml_set_cached_graph(lctx.sched,true);
+        // Set whether graph caching should be used for future tokens
+        lctx.cached_graph.is_active = !disable_cached_ggml_graph;
 
         // the output is always the last tensor in the graph
         res = gf->nodes[gf->n_nodes - 1];
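
Taken together, the two decode hunks form a per-token handshake: first the scheduler is told whether the graph cached during the previous token may be reused, then the context records whether caching is permitted for future tokens. Below is a condensed sketch of that flow; the function `update_graph_caching`, its parameters, and the stub body for `ggml_set_cached_graph` are assumptions for illustration, not code from the patch.

```cpp
// Condensed sketch of the per-token flow above; names marked below are assumptions.
#include <cstdlib>

struct ggml_backend_sched;                                   // opaque scheduler type
typedef struct ggml_backend_sched * ggml_backend_sched_t;    // matches the GGML typedef

// Stand-in stub for the function this patch adds; the real one tells the GGML
// scheduler whether the previously captured graph may be reused.
static void ggml_set_cached_graph(ggml_backend_sched_t /*sched*/, bool enable) {
    (void) enable;
}

// Hypothetical helper: cache_was_active, n_devices and n_tokens stand in for
// lctx.cached_graph.is_active, the device count and u_batch.n_tokens.
static bool update_graph_caching(ggml_backend_sched_t sched,
                                 bool cache_was_active,
                                 int  n_devices,
                                 int  n_tokens) {
    // 1. Reuse decision for the current token, based on the previous token.
    ggml_set_cached_graph(sched, cache_was_active);

    // 2. Eligibility for future tokens: an env var, multiple devices, or a
    //    batch larger than one token all disable caching for now.
    const bool disable = (std::getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
                      || (n_devices > 1)
                      || (n_tokens  > 1);

    return !disable;   // stored back into the cached-graph object's is_active
}
```

In the patch itself, the equivalent of this return value is written to lctx.cached_graph.is_active, so a decision made while decoding one token only takes effect on the next call to llama_decode_internal.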