@@ -871,8 +871,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;

-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

     if (batch.token) {
@@ -912,8 +910,6 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }

-    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
-
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
@@ -923,11 +919,61 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();

-    int64_t n_outputs_prev = 0;
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    // this is the sequence-aware batch that we construct based on the input batch
+    llama_sbatch sbatch;
+
+    // we then split the sbatch into a set of ubatches. the split logic is delegated to the KV cache
+    std::vector<llama_ubatch> ubatches;
+
+    // if we fail to find a slot for the batch, we can retry after applying a defrag.
+    // in some cases, this can free up some space, which would be enough to fit the ubatches
+    // ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2881412612
+    bool retry = true;
+
+    while (true) {
+        bool success = true;

-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+        sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);

+        while (sbatch.n_tokens > 0) {
+            ubatches.emplace_back(kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled));
+
+            const auto & ubatch = ubatches.back();
+
+            // find an empty KV slot that can fit the current ubatch
+            if (!kv_self->find_slot(ubatch)) {
+                success = false;
+                break;
+            }
+        }
+
+        if (success) {
+            break;
+        }
+
+        if (!success && !retry) {
+            LLAMA_LOG_WARN("%s: failed to fit the batch in the KV cache, batch size = %d\n", __func__, (int) n_tokens_all);
+            return 1;
+        }
+
+        // we failed to fit the sbatch once, and now we will try to defrag the KV cache and try to fit it again
+        retry = false;
+
+        kv_self->restore();
+        kv_self->defrag_sched(-1.0f);
+
+        kv_self_update();
+
+        ubatches.clear();
+    }
+
+    // we now have prepared the ubatches for this llama_decode and are ready to start processing
+
+    int64_t n_outputs_prev = 0;
+
+    for (const auto & ubatch : ubatches) {
         // count the outputs in this u_batch
         {
             int32_t n_outputs_new = 0;
@@ -945,13 +991,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }

-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-            return 1;
-        }
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);