@@ -5177,6 +5177,57 @@ struct llama_model_loader {
51775177 }
51785178};
51795179
5180+ // temporary allocate memory for the input batch if needed
5181+ static const llama_seq_id batch_default_seq_id = 0;
5182+ struct llama_batch_allocr {
5183+ std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
5184+ std::vector<llama_pos> pos;
5185+ std::vector<int32_t> n_seq_id;
5186+ std::vector<llama_seq_id *> seq_id;
5187+ std::vector<int8_t> logits;
5188+ struct llama_batch batch;
5189+ // optionally fulfill the batch returned by llama_batch_get_one
5190+ llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
5191+ batch = in_batch;
5192+ GGML_ASSERT(batch.n_tokens > 0);
5193+ if (!batch.pos) {
5194+ // determine the last position in KV cache
5195+ llama_pos last_pos = -1;
5196+ for (const auto & cell : ctx.kv_self.cells) {
5197+ if (cell.has_seq_id(batch_default_seq_id)) {
5198+ last_pos = std::max(last_pos, cell.pos);
5199+ }
5200+ }
5201+ last_pos++; // next position
5202+ pos.resize(batch.n_tokens);
5203+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5204+ pos[i] = i+last_pos;
5205+ }
5206+ batch.pos = pos.data();
5207+ }
5208+ if (!batch.n_seq_id) {
5209+ n_seq_id.resize(batch.n_tokens);
5210+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5211+ n_seq_id[i] = seq_id_0.size();
5212+ }
5213+ batch.n_seq_id = n_seq_id.data();
5214+ }
5215+ if (!batch.seq_id) {
5216+ seq_id.resize(batch.n_tokens + 1);
5217+ seq_id[batch.n_tokens] = NULL;
5218+ for (int32_t i = 0; i < batch.n_tokens; i++) {
5219+ seq_id[i] = seq_id_0.data();
5220+ }
5221+ batch.seq_id = seq_id.data();
5222+ }
5223+ if (!batch.logits) {
5224+ logits.resize(batch.n_tokens);
5225+ logits[logits.size() - 1] = true;
5226+ batch.logits = logits.data();
5227+ }
5228+ }
5229+ };
5230+
51805231template<>
51815232bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
51825233 uint32_t tmp;
@@ -17095,16 +17146,20 @@ static void llama_graph_compute(
1709517146//
1709617147static int llama_decode_internal(
1709717148 llama_context & lctx,
17098- llama_batch batch ) {
17149+ llama_batch inp_batch ) {
1709917150
1710017151 lctx.is_encoding = false;
17101- const uint32_t n_tokens_all = batch.n_tokens;
1710217152
17103- if (n_tokens_all == 0) {
17153+ if (inp_batch.n_tokens == 0) {
1710417154 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1710517155 return -1;
1710617156 }
1710717157
17158+ // temporary allocate memory for the input batch if needed
17159+ llama_batch_allocr batch_allocr(lctx, inp_batch);
17160+ const llama_batch & batch = batch_allocr.batch;
17161+ const uint32_t n_tokens_all = batch.n_tokens;
17162+
1710817163 const auto & model = lctx.model;
1710917164 const auto & hparams = model.hparams;
1711017165 const auto & cparams = lctx.cparams;
@@ -17409,17 +17464,20 @@ static int llama_decode_internal(
1740917464//
1741017465static int llama_encode_internal(
1741117466 llama_context & lctx,
17412- llama_batch batch ) {
17467+ llama_batch inp_batch ) {
1741317468
1741417469 lctx.is_encoding = true;
1741517470
17416- const uint32_t n_tokens = batch.n_tokens;
17417-
17418- if (n_tokens == 0) {
17471+ if (inp_batch.n_tokens == 0) {
1741917472 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1742017473 return -1;
1742117474 }
1742217475
17476+ // temporary allocate memory for the input batch if needed
17477+ llama_batch_allocr batch_allocr(lctx, inp_batch);
17478+ const llama_batch & batch = batch_allocr.batch;
17479+ const uint32_t n_tokens = batch.n_tokens;
17480+
1742317481 const auto & model = lctx.model;
1742417482 const auto & hparams = model.hparams;
1742517483 const auto & cparams = lctx.cparams;
@@ -21090,61 +21148,10 @@ void llama_batch_free(struct llama_batch batch) {
2109021148 if (batch.logits) free(batch.logits);
2109121149}
2109221150
21093- // temporary allocate memory for the input batch if needed
21094- static const llama_seq_id batch_default_seq_id = 0;
21095- struct llama_batch_allocr {
21096- std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
21097- std::vector<llama_pos> pos;
21098- std::vector<int32_t> n_seq_id;
21099- std::vector<llama_seq_id *> seq_id;
21100- std::vector<int8_t> logits;
21101- struct llama_batch batch;
21102- // optionally fulfill the batch returned by llama_batch_get_one
21103- llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
21104- batch = in_batch;
21105- if (!batch.pos) {
21106- // determine the last position in KV cache
21107- llama_pos last_pos = -1;
21108- for (const auto & cell : ctx->kv_self.cells) {
21109- if (cell.has_seq_id(batch_default_seq_id)) {
21110- last_pos = std::max(last_pos, cell.pos);
21111- }
21112- }
21113- last_pos++; // next position
21114- pos.resize(batch.n_tokens);
21115- for (int32_t i = 0; i < batch.n_tokens; i++) {
21116- pos[i] = i+last_pos;
21117- }
21118- batch.pos = pos.data();
21119- }
21120- if (!batch.n_seq_id) {
21121- n_seq_id.resize(batch.n_tokens);
21122- for (int32_t i = 0; i < batch.n_tokens; i++) {
21123- n_seq_id[i] = seq_id_0.size();
21124- }
21125- batch.n_seq_id = n_seq_id.data();
21126- }
21127- if (!batch.seq_id) {
21128- seq_id.resize(batch.n_tokens + 1);
21129- seq_id[batch.n_tokens] = NULL;
21130- for (int32_t i = 0; i < batch.n_tokens; i++) {
21131- seq_id[i] = seq_id_0.data();
21132- }
21133- batch.seq_id = seq_id.data();
21134- }
21135- if (!batch.logits) {
21136- logits.resize(batch.n_tokens);
21137- logits[logits.size() - 1] = true;
21138- batch.logits = logits.data();
21139- }
21140- }
21141- };
21142-
2114321151int32_t llama_encode(
2114421152 struct llama_context * ctx,
2114521153 struct llama_batch batch) {
21146- llama_batch_allocr batch_allocr(ctx, batch);
21147- const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
21154+ const int ret = llama_encode_internal(*ctx, batch);
2114821155 if (ret != 0) {
2114921156 LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
2115021157 }
@@ -21155,8 +21162,7 @@ int32_t llama_encode(
2115521162int32_t llama_decode(
2115621163 struct llama_context * ctx,
2115721164 struct llama_batch batch) {
21158- llama_batch_allocr batch_allocr(ctx, batch);
21159- const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
21165+ const int ret = llama_decode_internal(*ctx, batch);
2116021166 if (ret != 0) {
2116121167 LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2116221168 }
0 commit comments