@@ -3591,27 +3591,27 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should always be contiguous.

         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);

         int32_t min = cache.size - 1;
         int32_t max = 0;

         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];

                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -3670,7 +3670,7 @@ static bool llama_kv_cache_find_slot(

         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -3709,7 +3709,7 @@ static bool llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -3730,20 +3730,20 @@ static bool llama_kv_cache_find_slot(

         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];

             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -3795,10 +3795,10 @@ static bool llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];

-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
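
For context, the argument being renamed throughout this commit is a llama_ubatch (a micro-batch produced by splitting a llama_batch). A minimal sketch of the fields the code above relies on, assembled only from the usages visible in this diff (the real struct in llama.cpp has additional members, e.g. the embedding buffer, and its exact layout may differ):

// sketch, not the actual definition from llama.cpp
struct llama_ubatch {
    bool equal_seqs;          // every sequence contributes the same number of new tokens

    uint32_t n_tokens;        // total new tokens; expected to equal n_seq_tokens * n_seqs (asserted in llm_build_mamba below)
    uint32_t n_seq_tokens;    // new tokens per sequence
    uint32_t n_seqs;          // sequences in this micro-batch

    llama_token  *  token;    // [n_tokens] token ids, or null when embeddings are supplied instead
    llama_pos    *  pos;      // [n_tokens] positions
    int32_t      *  n_seq_id; // [n_seqs]   how many seq_ids each sequence carries
    llama_seq_id ** seq_id;   // [n_seqs]   per-sequence seq_id lists
};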
@@ -9178,21 +9178,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
         struct llama_context & lctx,
         const llama_hparams & hparams,
-        const llama_ubatch & batch,
+        const llama_ubatch & ubatch,
         struct ggml_tensor * tok_embd,
         const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;

     struct ggml_tensor * inpL;

-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);

         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -9766,7 +9766,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
         struct llama_context & lctx,
-        const llama_ubatch & batch,
+        const llama_ubatch & ubatch,
         struct ggml_cgraph * graph,
         struct ggml_tensor * cur,
         struct ggml_tensor * state_copy,
@@ -9782,17 +9782,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs = batch.n_seqs;
+    const int64_t n_seqs = ubatch.n_seqs;
     // Some variants of the Mamba arch (e.g. FalconMamba) do apply layer norm on B and Dt layers
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;

-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all = kv.v_l[il];
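
The three GGML_ASSERTs above pin down the layout that both llm_build_mamba and the recurrent branch of llama_kv_cache_find_slot assume: with equal_seqs, the micro-batch is a dense grid of n_seqs sequences with n_seq_tokens new tokens each, addressed by the flat index k = s*n_seq_tokens + i. A small self-contained illustration of that indexing, with made-up sizes (not code from the commit):

// illustrative only; mirrors the indexing used in llama_kv_cache_find_slot above
#include <cassert>
#include <cstdint>

static uint32_t flat_index(uint32_t s, uint32_t i, uint32_t n_seq_tokens) {
    return s*n_seq_tokens + i; // token i of sequence s
}

int main() {
    const uint32_t n_seqs = 2, n_seq_tokens = 5;
    const uint32_t n_tokens = n_seq_tokens * n_seqs; // the asserted invariant
    assert(n_tokens == 10);
    // the last token of sequence s lives at s*n_seq_tokens + n_seq_tokens - 1,
    // which is exactly where find_slot reads last_pos from ubatch.pos
    assert(flat_index(1, n_seq_tokens - 1, n_seq_tokens) == n_tokens - 1);
    return 0;
}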
@@ -20440,10 +20440,10 @@ struct llama_data_read {

         llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);

-        llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-        batch.n_tokens = cell_count;
-        batch.n_seq_tokens = cell_count;
-        batch.n_seqs = 1;
+        llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        ubatch.n_tokens = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs = 1;

         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -20457,20 +20457,20 @@ struct llama_data_read {
                 return false;
             }

-            batch.pos[i] = pos;
+            ubatch.pos[i] = pos;
         }
-        batch.n_seq_id[0] = 1;
-        batch.seq_id[0] = &dest_seq_id;
-        if (!llama_kv_cache_find_slot(kv_self, batch)) {
+        ubatch.n_seq_id[0] = 1;
+        ubatch.seq_id[0] = &dest_seq_id;
+        if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }

         // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-        GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
+        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
         GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
     } else {