@@ -3785,27 +3785,27 @@ static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
 // to the first cell of the slot.
 static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.

         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);

         int32_t min = cache.size - 1;
         int32_t max = 0;

         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];

                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -3864,7 +3864,7 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(

         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -3903,7 +3903,7 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -3924,20 +3924,20 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(

         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];

             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -3991,10 +3991,10 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];

-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
@@ -9931,21 +9931,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
         struct llama_context & lctx,
         const llama_hparams & hparams,
-        const llama_ubatch & batch,
+        const llama_ubatch & ubatch,
         struct ggml_tensor * tok_embd,
         const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;

     struct ggml_tensor * inpL;

-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);

         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -10518,7 +10518,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
         struct llama_context & lctx,
-        const llama_ubatch & batch,
+        const llama_ubatch & ubatch,
         struct ggml_cgraph * graph,
         struct ggml_tensor * cur,
         struct ggml_tensor * state_copy,
@@ -10534,17 +10534,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs = batch.n_seqs;
+    const int64_t n_seqs = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;

-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all = kv.v_l[il];
@@ -21828,10 +21828,10 @@ struct llama_data_read {

             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);

-            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-            batch.n_tokens = cell_count;
-            batch.n_seq_tokens = cell_count;
-            batch.n_seqs = 1;
+            llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+            ubatch.n_tokens = cell_count;
+            ubatch.n_seq_tokens = cell_count;
+            ubatch.n_seqs = 1;

             for (uint32_t i = 0; i < cell_count; ++i) {
                 llama_pos pos;
@@ -21845,20 +21845,20 @@ struct llama_data_read {
                     return false;
                 }

-                batch.pos[i] = pos;
+                ubatch.pos[i] = pos;
             }
-            batch.n_seq_id[0] = 1;
-            batch.seq_id[0] = &dest_seq_id;
-            if (!llama_kv_cache_find_slot(kv_self, batch)) {
+            ubatch.n_seq_id[0] = 1;
+            ubatch.seq_id[0] = &dest_seq_id;
+            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
                 LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
                 return false;
             }

             // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
             // Assume that this is one contiguous block of cells
             GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
             GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
         } else {
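
For context on the fields this rename touches: the hunks above index ubatch.pos per sequence and assert that n_tokens == n_seq_tokens * n_seqs when equal_seqs holds. Below is a minimal sketch of that layout using a hypothetical stand-in struct (not the real llama_ubatch definition), illustrating the same last-position indexing and invariant:

    // Hypothetical stand-in for the ubatch fields used above (illustration only).
    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct ubatch_sketch {
        uint32_t n_tokens;              // total tokens in this micro-batch
        uint32_t n_seqs;                // number of sequences in the batch
        uint32_t n_seq_tokens;          // tokens per sequence (equal_seqs layout)
        std::vector<int32_t> pos;       // one position per token, grouped by sequence
    };

    // Last position of sequence s, mirroring the indexing in llama_kv_cache_find_slot:
    //   pos[n_seq_tokens * s + n_seq_tokens - 1]
    static int32_t last_pos_of_seq(const ubatch_sketch & ub, uint32_t s) {
        // same invariant as the GGML_ASSERT in llm_build_mamba above
        assert(ub.n_tokens == ub.n_seq_tokens * ub.n_seqs);
        return ub.pos[ub.n_seq_tokens * s + ub.n_seq_tokens - 1];
    }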