@@ -871,8 +871,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;

-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

     if (batch.token) {
@@ -912,8 +910,6 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }

-    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
-
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
@@ -923,11 +919,59 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();

-    int64_t n_outputs_prev = 0;
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    // this is the sequence-aware batch that we construct based on the input batch
+    llama_sbatch sbatch;
+
+    // we then split the sbatch into a set of ubatches. the split logic is delegated to the KV cache
+    std::vector<llama_ubatch> ubatches;
+
+    // if we fail to find a slot for the batch, we can retry after applying a defrag.
+    // in some cases, this can free up some space, which would be enough to fit the ubatches
+    // ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2881412612
+    bool retry = true;

-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+    while (true) {
+        bool success = true;
+
+        sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
+
+        while (sbatch.n_tokens > 0) {
+            ubatches.emplace_back(kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled));
+
+            // find an empty KV slot that can fit the current ubatch
+            if (!kv_self->find_slot(ubatches.back())) {
+                success = false;
+                break;
+            }
+        }
+
+        if (success) {
+            break;
+        }

+        if (!retry) {
+            LLAMA_LOG_WARN("%s: failed to fit the batch in the KV cache, batch size = %d\n", __func__, (int) n_tokens_all);
+            return 1;
+        }
+
+        // we failed to fit the sbatch once, and now we will try to defrag the KV cache and try to fit it again
+        retry = false;
+
+        kv_self->restore();
+        kv_self->defrag_sched(-1.0f);
+
+        kv_self_update();
+
+        ubatches.clear();
+    }
+
+    // we now have prepared the ubatches for this llama_decode and are ready to start processing
+
+    int64_t n_outputs_prev = 0;
+
+    for (const auto & ubatch : ubatches) {
         // count the outputs in this u_batch
         {
             int32_t n_outputs_new = 0;
@@ -945,13 +989,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }

-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-            return 1;
-        }
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

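The change above prepares all ubatches up front and, if any of them fails to get a KV slot, rolls back the partial allocations, forces a defrag, and retries exactly once before giving up. A minimal standalone sketch of that retry pattern, using a toy cache with hypothetical try_fit_all/restore/defrag helpers rather than the real llama.cpp API:

#include <cstdio>
#include <vector>

// toy stand-in for the KV cache; 'contiguous' is the largest free slot available now,
// 'reclaimable' is how many cells a defrag could merge back into that slot
struct toy_kv_cache {
    int contiguous  = 8;
    int reclaimable = 24;

    // try to reserve a slot for every ubatch; fail if one does not fit contiguously
    bool try_fit_all(const std::vector<int> & ubatch_sizes) const {
        for (int n : ubatch_sizes) {
            if (n > contiguous) {
                return false;
            }
        }
        return true;
    }

    void restore() {
        // roll back any partial reservations (a no-op in this toy model)
    }

    void defrag() {
        // compaction merges the reclaimable cells into the contiguous free region
        contiguous += reclaimable;
        reclaimable  = 0;
    }
};

// mirrors the retry-once control flow: fit all ubatches, defrag and retry on failure
int decode_with_retry(toy_kv_cache & kv, const std::vector<int> & ubatch_sizes) {
    bool retry = true;

    while (true) {
        if (kv.try_fit_all(ubatch_sizes)) {
            break;
        }

        if (!retry) {
            std::printf("failed to fit the batch in the KV cache\n");
            return 1;
        }

        // first failure: undo partial allocations, defrag once, and try again
        retry = false;
        kv.restore();
        kv.defrag();
    }

    // ... every ubatch now has a slot, so the graph for each one can be computed ...
    return 0;
}

int main() {
    toy_kv_cache kv;
    const std::vector<int> ubatch_sizes = { 16, 16 }; // only fits after the defrag pass
    std::printf("decode returned %d\n", decode_with_retry(kv, ubatch_sizes));
    return 0;
}

As in the patch, the retry flag guarantees the defrag path runs at most once per decode call, so a batch that genuinely cannot fit still fails quickly instead of looping.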