@@ -274,13 +274,16 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -291,7 +294,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1);
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -302,7 +305,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
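The three hunks above all follow the same pattern: instead of mutating the cache in place with set_full(), the constructor asks the cache for an explicit worst-case state object and threads it into every graph reservation. Below is a minimal sketch of that pattern; the names (memory_state, kv_cache, reserve_worst_case) are hypothetical stand-ins, not the real llama_kv_cache / llama_memory_state_i interfaces.

// Illustrative sketch only -- hypothetical stand-ins for the llama.cpp types.
#include <memory>
#include <stdexcept>

struct memory_state {
    int n_cells = 0; // e.g. how many cells a worst-case graph must cover
};

struct kv_cache {
    int capacity = 4096;

    // Instead of flipping an internal "full" flag (the old set_full()),
    // return an explicit snapshot whose lifetime the caller controls.
    std::unique_ptr<memory_state> init_full() const {
        auto st = std::make_unique<memory_state>();
        st->n_cells = capacity; // worst case: every cell occupied
        return st;
    }
};

void reserve_worst_case(const kv_cache & kv) {
    const auto kv_state = kv.init_full();
    if (!kv_state) {
        throw std::runtime_error("failed to initialize KV cache");
    }
    // Every reservation reads the same immutable snapshot, e.g.:
    // graph_reserve(n_tokens, n_seqs, n_outputs, kv_state.get());
}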
@@ -430,12 +433,15 @@ void llama_context::kv_self_update() {
 
     if (kv_self->update(*this)) {
         // if the KV cache did any computation, we have to reserve a new worst-case graph
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         const uint32_t n_seqs   = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
@@ -651,7 +657,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         if (ret) {
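The process_ubatch() change is the consumer side of the same refactor: the memory state used for graph construction becomes an explicit parameter owned by the caller. A rough sketch of the flow, again with hypothetical names in place of the real llama.cpp types:

// Illustrative only -- hypothetical stand-ins for llama_memory_state_i,
// llama_ubatch and llm_graph_result.
#include <memory>

struct memory_state_i { int n_kv = 0; };
struct ubatch_t       { int n_tokens = 0; };
struct graph_result   { };

// graph_build() takes the state explicitly...
std::unique_ptr<graph_result> graph_build(const ubatch_t &, const memory_state_i * mstate) {
    if (!mstate) return nullptr; // no silent fallback to a context member
    return std::make_unique<graph_result>();
}

// ...and process_ubatch() just forwards whatever state its caller
// prepared for this particular ubatch.
std::unique_ptr<graph_result> process_ubatch(const ubatch_t & ub, const memory_state_i * mstate) {
    return graph_build(ub, mstate);
}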
@@ -1269,7 +1275,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1289,7 +1295,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
 
     this->n_outputs = save_n_outputs;
 
@@ -1310,10 +1316,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-            ggml_context * ctx,
-             ggml_cgraph * gf,
-      const llama_ubatch & ubatch,
-          llm_graph_type   gtype) {
+                  ggml_context * ctx,
+                   ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+                llm_graph_type   gtype,
+    const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
@@ -1325,7 +1332,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
-            /*.memory      =*/ memory.get(),
+            /*.mstate      =*/ mstate,
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
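The final piece is the parameter block handed to model.build_graph(): the /*.memory =*/ memory.get() field becomes /*.mstate =*/ mstate, so graph construction receives whatever state the caller chose (the worst-case state during reservation, the per-ubatch state during decode) rather than reaching into the context's own memory member. A schematic version follows; graph_params here is a made-up stand-in, not the real llm_graph_params layout.

// Schematic only -- hypothetical types, not the llama.cpp aggregate.
struct memory_state_i { };

struct graph_params {
    const memory_state_i * mstate;    // was: a pointer fetched via memory.get()
    int                    n_outputs;
};

graph_params make_graph_params(const memory_state_i * mstate, int n_outputs) {
    return {
        /*.mstate    =*/ mstate,      // dependency injected by the caller
        /*.n_outputs =*/ n_outputs,
    };
}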
@@ -2047,7 +2054,7 @@ void llama_context::opt_epoch_iter(
         n_outputs = ubatch.n_tokens;
 
         auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
         struct ggml_context * ctx_compute_opt;
         {