@@ -3591,27 +3591,27 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs   = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -3670,7 +3670,7 @@ static bool llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -3709,7 +3709,7 @@ static bool llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -3730,20 +3730,20 @@ static bool llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -3795,10 +3795,10 @@ static bool llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
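
For readers tracking the rename, every member the function touches is visible in the hunks above. Below is a minimal sketch of that shape, assuming the llama.h typedefs (llama_token, llama_pos, llama_seq_id are all int32_t); the real llama_ubatch in llama.cpp carries additional members (e.g. embd, output) not reconstructed here.

#include <cstdint>

typedef int32_t llama_token;  // as in llama.h
typedef int32_t llama_pos;    // as in llama.h
typedef int32_t llama_seq_id; // as in llama.h

// Sketch of the micro-batch view consumed by llama_kv_cache_find_slot,
// inferred from the accesses above; not the exact upstream definition.
struct llama_ubatch_sketch {
    bool            equal_seqs;   // every sequence contributes the same number of new tokens
    uint32_t        n_tokens;     // total tokens; n_seq_tokens * n_seqs when equal_seqs holds
    uint32_t        n_seq_tokens; // new tokens per sequence
    uint32_t        n_seqs;       // sequences in this micro-batch

    llama_token   * token;        // token ids, or nullptr when embeddings are passed instead
    llama_pos     * pos;          // per-token positions, sequence-major layout
    int32_t       * n_seq_id;     // per-sequence count of ids in seq_id[s]
    llama_seq_id ** seq_id;       // seq_id[s][j]: the j-th sequence id of sequence s
};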
@@ -9178,21 +9178,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-          const llama_ubatch & batch,
+          const llama_ubatch & ubatch,
           struct ggml_tensor * tok_embd,
           const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
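
The branch above is where the graph builder distinguishes the two input kinds: a ubatch carries either token ids or precomputed embeddings, never both. A toy illustration of that dispatch follows; describe_input is a hypothetical name, not part of llama.cpp.

#include <cassert>
#include <cstdint>
#include <cstdio>

typedef int32_t llama_token; // as in llama.h

// Hypothetical stand-in for the token-vs-embedding dispatch in
// llm_build_inp_embd: exactly one input source is present per micro-batch.
static void describe_input(const llama_token * token, const float * embd, uint32_t n_tokens) {
    assert((token != nullptr) != (embd != nullptr)); // one source, never both
    if (token) {
        // token path: an I32 id tensor, later expanded via ggml_get_rows
        printf("inp_tokens: %u ids\n", n_tokens);
    } else {
        // embedding path: a ready-made F32 [n_embd, n_tokens] input
        printf("inp_embd: %u vectors\n", n_tokens);
    }
}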
@@ -9766,7 +9766,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -9782,17 +9782,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs  = batch.n_seqs;
+    const int64_t n_seqs  = ubatch.n_seqs;
     // Some variants of the Mamba arch (e.g. FalconMamba) do apply layer norm on B and Dt layers
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];
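
The two asserts pin down the layout the earlier kv-cache hunks rely on: with equal_seqs, the flat pos (and token) arrays are sequence-major, so token i of sequence s sits at index s * n_seq_tokens + i. A self-contained sketch of that indexing, with made-up position values:

#include <cstdint>
#include <cstdio>

typedef int32_t llama_pos; // as in llama.h

int main() {
    const uint32_t n_seqs = 2, n_seq_tokens = 3;
    // sequence-major layout: seq 0 holds positions 10..12, seq 1 holds 7..9
    const llama_pos pos[n_seqs * n_seq_tokens] = {10, 11, 12, 7, 8, 9};

    for (uint32_t s = 0; s < n_seqs; ++s) {
        // same indexing llama_kv_cache_find_slot uses for a sequence's last position
        const llama_pos last_pos = pos[n_seq_tokens * s + n_seq_tokens - 1];
        printf("seq %u: last pos = %d\n", s, last_pos);
    }
    return 0;
}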
@@ -20440,10 +20440,10 @@ struct llama_data_read {
 
             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
-            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-            batch.n_tokens = cell_count;
-            batch.n_seq_tokens = cell_count;
-            batch.n_seqs = 1;
+            llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+            ubatch.n_tokens = cell_count;
+            ubatch.n_seq_tokens = cell_count;
+            ubatch.n_seqs = 1;
 
             for (uint32_t i = 0; i < cell_count; ++i) {
                 llama_pos pos;
@@ -20457,20 +20457,20 @@ struct llama_data_read {
                     return false;
                 }
 
-                batch.pos[i] = pos;
+                ubatch.pos[i] = pos;
             }
-            batch.n_seq_id[0] = 1;
-            batch.seq_id[0] = &dest_seq_id;
-            if (!llama_kv_cache_find_slot(kv_self, batch)) {
+            ubatch.n_seq_id[0] = 1;
+            ubatch.seq_id[0] = &dest_seq_id;
+            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
                 LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
                 return false;
             }
 
             // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
             // Assume that this is one contiguous block of cells
             GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
             GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
         } else {
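
The DEBUG CHECK asserts encode the contract the restore path depends on: for a single-sequence ubatch on a non-recurrent cache, find_slot places all cell_count cells as one contiguous block starting at kv_self.head. A toy version of that check, with a hypothetical cell type standing in for llama_kv_cell:

#include <cassert>
#include <cstdint>

typedef int32_t llama_pos; // as in llama.h

struct cell_sketch { llama_pos pos; }; // hypothetical stand-in for llama_kv_cell

// Mirrors the DEBUG CHECK above: the first and last cells of a contiguous
// slot must carry the first and last restored positions.
static void check_contiguous_slot(const cell_sketch * cells, uint32_t head,
                                  const llama_pos * pos, uint32_t cell_count) {
    assert(cells[head].pos == pos[0]);
    assert(cells[head + cell_count - 1].pos == pos[cell_count - 1]);
}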