@@ -424,28 +424,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
424424 return kv_self;
425425}
426426
427- void llama_context::kv_self_update () {
427+ bool llama_context::kv_self_update () {
428428 if (!memory) {
429- return ;
429+ return false ;
430430 }
431431
432432 llama_kv_cache * kv_self = static_cast <llama_kv_cache *>(memory.get ());
433433
434- if (kv_self->update (*this )) {
435- // if the KV cache did any computation, we have to reserve a new worst-case graph
436- const auto kv_state = kv_self->init_full ();
437- if (!kv_state) {
438- throw std::runtime_error (" failed to initialize KV cache" );
439- }
434+ if (!kv_self->update (*this )) {
435+ // no updates have been performed
436+ return false ;
437+ }
440438
441- const uint32_t n_seqs = cparams.n_seq_max ;
442- const uint32_t n_tokens = std::min (cparams.n_ctx , cparams.n_ubatch );
439+ // if the KV cache did any computation, we have to reserve a new worst-case graph
440+ const auto kv_state = kv_self->init_full ();
441+ if (!kv_state) {
442+ throw std::runtime_error (" failed to initialize KV cache" );
443+ }
443444
444- auto * gf = graph_reserve (n_tokens, n_seqs, n_tokens, kv_state.get ());
445- if (!gf) {
446- LLAMA_LOG_ERROR (" %s: failed to reserve graph after the KV cache update\n " , __func__);
447- }
445+ const uint32_t n_seqs = cparams.n_seq_max ;
446+ const uint32_t n_tokens = std::min (cparams.n_ctx , cparams.n_ubatch );
447+
448+ auto * gf = graph_reserve (n_tokens, n_seqs, n_tokens, kv_state.get ());
449+ if (!gf) {
450+ LLAMA_LOG_ERROR (" %s: failed to reserve graph after the KV cache update\n " , __func__);
448451 }
452+
453+ return true ;
449454}
450455
451456enum llama_pooling_type llama_context::pooling_type () const {
@@ -933,24 +938,44 @@ int llama_context::decode(llama_batch & inp_batch) {
933938 // handle any pending defrags/shifts
934939 kv_self_update ();
935940
936- auto kv_state = kv_self->init_batch (batch, cparams.n_ubatch , embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
937- if (!kv_state) {
938- return -2 ;
939- }
941+ llama_memory_state_ptr kv_state;
940942
941- switch (kv_state->get_status ()) {
942- case LLAMA_MEMORY_STATUS_SUCCESS:
943- {
944- } break ;
945- case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
946- {
947- // not a fatal error, we can re-try with a different batch
948- return 1 ;
949- }
950- case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
951- {
952- return -2 ;
953- }
943+ bool did_defrag = false ;
944+
945+ while (true ) {
946+ kv_state = kv_self->init_batch (batch, cparams.n_ubatch , embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
947+ if (!kv_state) {
948+ return -2 ;
949+ }
950+
951+ switch (kv_state->get_status ()) {
952+ case LLAMA_MEMORY_STATUS_SUCCESS:
953+ {
954+ } break ;
955+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
956+ {
957+ if (!did_defrag) {
958+ did_defrag = true ;
959+
960+ kv_self->defrag_sched (-1 .0f );
961+ if (kv_self_update ()) {
962+ LLAMA_LOG_DEBUG (" %s: failed to init batch of size %d, retrying after defrag\n " , __func__, batch.n_tokens );
963+
964+ continue ;
965+ }
966+ }
967+
968+ LLAMA_LOG_WARN (" %s: failed to find KV cache slot for batch of size %d\n " , __func__, batch.n_tokens );
969+
970+ return 1 ;
971+ }
972+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
973+ {
974+ return -2 ;
975+ }
976+ }
977+
978+ break ;
954979 }
955980
956981 // reserve output buffer
@@ -2646,22 +2671,8 @@ int32_t llama_encode(
26462671int32_t llama_decode (
26472672 llama_context * ctx,
26482673 llama_batch batch) {
2649- int ret = ctx->decode (batch);
2650-
2651- // defrag and try again
2652- // TODO: distinguish return code when we are sure that even after defrag there is no space available
2653- if (ret == 1 ) {
2654- llama_kv_self_defrag (ctx);
2655- ret = ctx->decode (batch);
2656-
2657- if (ret == 1 ) {
2658- LLAMA_LOG_WARN (" %s: failed to find KV cache slot for batch of size %d\n " , __func__, batch.n_tokens );
2659-
2660- return ret;
2661- }
2662- }
2663-
2664- if (ret != 0 ) {
2674+ const int ret = ctx->decode (batch);
2675+ if (ret != 0 && ret != 1 ) {
26652676 LLAMA_LOG_ERROR (" %s: failed to decode, ret = %d\n " , __func__, ret);
26662677 }
26672678
0 commit comments