 #include <cstring>
 #include <ctime>
 #include <functional>
+#include <cinttypes>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -7751,7 +7752,7 @@ static struct ggml_cgraph * llama_build_graph(
 // (for non-recurrent models) or cleaned (for recurrent models)
 //
 //   - lctx:      llama context
-//   - batch:     batch to evaluate
+//   - inp_batch: batch to evaluate
 //
 // return 0 on success
 // return positive int on warning
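
Note: the comment block above is the whole error contract of llama_decode_impl(): 0 on success, a positive value for a recoverable warning, a negative value for a hard error (-1 invalid token, -2 output buffer reservation failed, -3 ubatch preparation failed, as seen further down in this diff). A minimal caller-side sketch of how that contract is consumed, assuming the public llama_decode() wrapper that forwards to this function:

    // sketch only: branch on the return convention documented above
    const int32_t ret = llama_decode(ctx, batch);
    if (ret < 0) {
        // hard error: invalid token, output reservation failure, ubatch preparation failure, ...
        fprintf(stderr, "llama_decode failed: %d\n", ret);
    } else if (ret > 0) {
        // warning (e.g. the graph computation was aborted); the context is still usable
        fprintf(stderr, "llama_decode warning: %d\n", ret);
    }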
@@ -7774,98 +7775,34 @@ static int llama_decode_impl(
 
     const llama_batch & batch = batch_allocr.batch;
 
-    const uint32_t n_tokens_all = batch.n_tokens;
-
     const auto & model   = lctx.model;
     const auto & vocab   = model.vocab;
-    const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd));  // NOLINT
-
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-    lctx.n_queued_tokens += n_tokens_all;
-
+    const int32_t n_vocab = vocab.n_tokens();
     const int64_t n_embd = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+    // TODO: try catch
+    auto bman = lctx.prepare_batch(batch);
 
-    lctx.embd_seq.clear();
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
-        }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs = 1;
-    }
+    const auto n_outputs_all = bman->n_outputs_all;
 
     // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+    // TODO: move to batch manager?
+    if (llama_output_reserve(lctx, bman->n_outputs_all) < (size_t) n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
         return -2;
     };
 
-    const bool logits_all = n_outputs == n_tokens_all;
-
-    //auto & kv_self = lctx.kv_self;
-    //llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    //lctx.sbatch.from_batch(batch, n_embd,
-    //        /* simple_split */ !kv_self.recurrent,
-    //        /* logits_all   */ logits_all);
-
-    auto batch_manager = lctx.prepare_batch(batch, logits_all);
+    int64_t n_outputs_prev = 0;
 
     while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = batch_manager->next();
-
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
-            }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
-        }
+        llama_ubatch ubatch = bman->next();
 
-        if (!batch_manager->prepare()) {
+        if (!bman->prepare()) {
             LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-            batch_manager->restore();
+            bman->restore();
             return -3;
         }
 
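
The hunk above replaces the inline token validation, output counting and ubatch bookkeeping with a single batch-manager object. The interface below is reconstructed purely from the calls visible in this diff (prepare_batch(), n_outputs_all, next(), prepare(), restore(), update(), finalize()); the struct name and exact signatures are assumptions, not the declaration from the commit:

    // assumed shape of the object returned by lctx.prepare_batch(batch); the name and
    // virtual-interface layout are illustrative, only the members used in the diff are real
    struct llama_batch_manager_i {
        int64_t n_outputs_all = 0;        // total number of outputs requested by the batch

        virtual ~llama_batch_manager_i() = default;

        virtual llama_ubatch next()  = 0; // split off and return the next micro-batch
        virtual bool prepare()       = 0; // per-ubatch setup (e.g. find a KV-cache slot); false on failure
        virtual void restore()       = 0; // undo per-ubatch state changes after a failure
        virtual void update()        = 0; // commit per-ubatch state after a successful compute
        virtual void finalize()      = 0; // per-batch bookkeeping once all ubatches are processed
    };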
@@ -7927,9 +7864,9 @@ static int llama_decode_impl(
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
 
-        const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
+        const auto compute_status = lctx.compute_graph(gf, ubatch.n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            batch_manager->restore();
+            bman->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -7941,7 +7878,7 @@ static int llama_decode_impl(
             }
         }
 
-        batch_manager->update();
+        bman->update();
 
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
@@ -7958,7 +7895,7 @@ static int llama_decode_impl(
             const int32_t n_outputs_new = lctx.n_outputs;
 
             if (n_outputs_new) {
-                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all);
                 GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
                 ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
             }
@@ -7978,7 +7915,7 @@ static int llama_decode_impl(
                         const int32_t n_outputs_new = lctx.n_outputs;
 
                         if (n_outputs_new) {
-                            GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+                            GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all);
                             GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
                             ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                         }
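
The two hunks above only swap the assertion bound from n_outputs to n_outputs_all; the copies themselves still write n_outputs_new rows per ubatch into the context-owned logits/embeddings buffers at offset n_outputs_prev. On the caller side those buffers are read back through the public accessors; a small sketch, assuming output was requested for the last token of the batch:

    // sketch: read back results after a successful decode
    const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);      // logits of the last token
    const float * embd   = llama_get_embeddings_ith(ctx, batch.n_tokens - 1);  // null unless embeddings are enabled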
@@ -8027,9 +7964,9 @@ static int llama_decode_impl(
     {
         bool sorted_output = true;
 
-        GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
+        GGML_ASSERT(lctx.sbatch.out_ids.size() == (size_t) n_outputs_all);
 
-        for (size_t i = 0; i < n_outputs; ++i) {
+        for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
             size_t out_id = lctx.sbatch.out_ids[i];
             lctx.output_ids[out_id] = i;
             if (out_id != i) {
@@ -8043,12 +7980,12 @@ static int llama_decode_impl(
     }
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    lctx.n_outputs = n_outputs;
+    lctx.n_outputs = n_outputs_all;
 
     // wait for the computation to finish (automatically done when obtaining the model output)
     //llama_synchronize(&lctx);
 
-    batch_manager->finalize();
+    bman->finalize();
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
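
Finally, the out_ids bookkeeping in the last two hunks is what keeps per-token lookups stable under ubatch splitting: row j of the output buffers corresponds to batch position sbatch.out_ids[j], and the inverse map lctx.output_ids[] translates a caller-supplied batch index back to a buffer row. In sketch form (mirroring, not quoting, the accessor logic):

    // given a batch position i for which output was requested:
    //   row    = lctx.output_ids[i];          // which row of the output buffer holds it
    //   logits = lctx.logits + row * n_vocab; // the logit vector for that position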