@@ -8432,74 +8432,33 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_impl(
-        llama_context & lctx,
-        llama_batch     inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
-
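+// prepare lctx.sbatch from a full input batch:
+// validates the tokens, counts the requested outputs into n_outputs,
+// clears lctx.embd_seq and reserves the output buffer
+// returns 0 on success, a negative int on error
+//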
+static int llama_prepare_sbatch(
+        llama_context & lctx,
+        const llama_batch & batch,
+        uint32_t & n_outputs) {
     const auto & model   = lctx.model;
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const int64_t n_embd = hparams.n_embd;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
     }
-
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
     lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs      = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     lctx.embd_seq.clear();
 
     // count outputs
@@ -8515,7 +8474,7 @@ static int llama_decode_impl(
     }
 
     lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -8524,70 +8483,148 @@ static int llama_decode_impl(
         return -2;
     };
 
-    while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
+    return 0;
+}
+
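+// split the next ubatch off lctx.sbatch and set lctx.n_outputs accordingly;
+// for causal attention this also updates the KV cache and finds a slot for the
+// ubatch, saving it in kv_slot_restorer so that it can be restored on failure
+// returns 0 on success, 1 (warning) if no KV cache slot could be found
+//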
+static int llama_prepare_ubatch(
+        llama_context & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch & ubatch,
+        const uint32_t n_outputs,
+        const uint32_t n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
         } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
         }
-        const uint32_t n_tokens = ubatch.n_tokens;
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
 
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
 
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
             }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
 
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
 
-        GGML_ASSERT(n_threads > 0);
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
+    return 0;
+}
 
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_impl(
+        llama_context & lctx,
+        llama_batch     inp_batch) {
 
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
+    lctx.is_encoding = false;
 
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();
+
+    uint32_t n_outputs      = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
+    while (lctx.sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+        {
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
         }
 
+        const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
+        GGML_ASSERT(n_threads > 0);
+
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -8640,7 +8677,7 @@ static int llama_decode_impl(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
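
For reference, a minimal caller-side sketch (not part of this commit) of how the 0 / positive / negative return convention can be handled, assuming the public llama_decode() wrapper forwards the value returned by llama_decode_impl() unchanged; decode_or_report is a hypothetical helper:

#include <cstdio>

#include "llama.h"

// hypothetical helper, for illustration only
static bool decode_or_report(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;  // success
    }
    if (ret > 0) {
        // warning, e.g. no KV cache slot was found for the whole batch;
        // the cache state was restored, so the caller may retry with a smaller batch
        fprintf(stderr, "llama_decode: warning %d, consider retrying with a smaller batch\n", ret);
        return false;
    }
    // negative: hard error (empty batch, invalid token, failed allocation, ...)
    fprintf(stderr, "llama_decode: error %d\n", ret);
    return false;
}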