@@ -119,27 +119,27 @@ bool llama_kv_cache_init(
 
 struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-        const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs   = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+        const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
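
For readers following the rename: below is a minimal, self-contained sketch of the `ubatch` layout this hunk relies on. `mini_ubatch` and `seq_ids_fit` are hypothetical stand-ins for illustration, not the real llama.cpp types; the point is that in the recurrent branch a `seq_id` doubles as a cache-cell index, so every id must be smaller than `cache.size`.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_seq_id = int32_t;

// Simplified stand-in for the field subset of llama_ubatch used above.
struct mini_ubatch {
    uint32_t n_seqs;                                // sequences in this micro-batch
    uint32_t n_seq_tokens;                          // new tokens per sequence (equal_seqs)
    std::vector<uint32_t> n_seq_id;                 // per-sequence count of seq ids
    std::vector<std::vector<llama_seq_id>> seq_id;  // seq_id[s][j]
};

// Mirrors the bounds check in the recurrent branch: a seq_id indexes a
// cache cell directly, so it has to be non-negative and < cache_size.
static bool seq_ids_fit(const mini_ubatch & ub, uint32_t cache_size) {
    for (uint32_t s = 0; s < ub.n_seqs; ++s) {
        for (uint32_t j = 0; j < ub.n_seq_id[s]; ++j) {
            const llama_seq_id id = ub.seq_id[s][j];
            if (id < 0 || (uint32_t) id >= cache_size) {
                return false; // too big (or negative) seq_id
            }
        }
    }
    return true;
}

int main() {
    mini_ubatch ub = { 2, 4, {1, 1}, {{0}, {5}} };
    printf("fits in 4 cells: %d\n", seq_ids_fit(ub, 4)); // 0: seq_id 5 >= 4
    printf("fits in 8 cells: %d\n", seq_ids_fit(ub, 8)); // 1
}
```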
@@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
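
The `has_cell` check above reads the `tail` field. A minimal sketch of that bookkeeping, assuming the one-cell-per-sequence model described in the earlier comments (`mini_cell` is a simplified stand-in for `llama_kv_cell`):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// For recurrent models, cells[seq_id].tail is the index of the cell that
// currently stores that sequence's state, or -1 if it has no cell yet.
struct mini_cell {
    int32_t tail = -1;
};

int main() {
    std::vector<mini_cell> cells(8);
    cells[3].tail = 5;  // sequence 3's state lives in cell 5

    for (int32_t seq_id : {3, 4}) {
        const bool has_cell = cells[seq_id].tail >= 0;  // mirrors the check above
        printf("seq %d has cell: %d\n", seq_id, has_cell);
    }
}
```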
@@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
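
The swap above packs the ubatch's cells into one contiguous block. A hedged sketch of the same idea with simplified state (`gather_contiguous` is a hypothetical helper; the real code swaps whole `llama_kv_cell`s and also repairs their `seq_id` sets):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Once a contiguous slot [min, min + n_seqs) is chosen, each sequence's
// tail cell is swapped into position s + min. Tails and per-cell payloads
// are split into two plain vectors here for brevity.
static void gather_contiguous(std::vector<int32_t> & tail,       // tail[seq_id] -> cell index
                              std::vector<int32_t> & cell_state, // per-cell payload stand-in
                              const std::vector<int32_t> & batch_seq_ids,
                              int32_t min) {
    for (int32_t s = 0; s < (int32_t) batch_seq_ids.size(); ++s) {
        const int32_t dst_id = s + min;
        const int32_t src_id = tail[batch_seq_ids[s]];
        if (dst_id != src_id) {
            std::swap(cell_state[dst_id], cell_state[src_id]);
            tail[batch_seq_ids[s]] = dst_id;
            // the real code also fixes up the tail of whatever sequence
            // previously occupied dst_id
        }
    }
}

int main() {
    std::vector<int32_t> tail  = {-1, 6, -1, 4};   // seq 1 -> cell 6, seq 3 -> cell 4
    std::vector<int32_t> state = {0, 0, 0, 0, 30, 0, 10, 0};
    gather_contiguous(tail, state, {1, 3}, 2);     // pack seqs 1 and 3 into cells 2 and 3
    // now state[2] == 10 (seq 1) and state[3] == 30 (seq 3)
}
```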
@@ -258,20 +258,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
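
The warning above fires when positions are not consecutive. A small sketch of that check, assuming the flat `pos` layout used here, where token `i` of sequence `s` sits at `n_seq_tokens*s + i` (`is_consecutive` is an illustrative helper, not part of llama.cpp):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// The last new position for sequence s is at n_seq_tokens*s + n_seq_tokens - 1.
// If the cell already holds state ending at pos p, the batch is consecutive
// only when last_pos == p + n_seq_tokens; a fresh cell (p < 0) always passes.
static bool is_consecutive(const std::vector<int32_t> & pos,
                           uint32_t n_seq_tokens, uint32_t s, int32_t cell_pos) {
    const int32_t last_pos = pos[n_seq_tokens * s + n_seq_tokens - 1];
    return cell_pos < 0 || last_pos == cell_pos + (int32_t) n_seq_tokens;
}

int main() {
    // sequence 0 continues a state ending at pos 3 with 4 new tokens at pos 4..7
    std::vector<int32_t> pos = {4, 5, 6, 7};
    printf("%d\n", is_consecutive(pos, 4, 0, 3));  // 1: 7 == 3 + 4
    printf("%d\n", is_consecutive(pos, 4, 0, 5));  // 0: would trigger the warning
}
```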
@@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
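
For contrast with the recurrent branch, this final hunk is the ordinary (non-recurrent) fill. A self-contained sketch under simplified types (`mini_cell` and `fill_slot` are hypothetical stand-ins): token `k = s*n_seq_tokens + i` lands in cell `head + k`, which records the token's position and every `seq_id` it belongs to.

```cpp
#include <cstdint>
#include <set>
#include <vector>

struct mini_cell {
    int32_t pos = -1;
    std::set<int32_t> seq_id;
};

// Once a free run of n_seqs*n_seq_tokens cells starting at `head` is found,
// write each token's position and sequence memberships into its cell.
static void fill_slot(std::vector<mini_cell> & cells, uint32_t head,
                      uint32_t n_seqs, uint32_t n_seq_tokens,
                      const std::vector<int32_t> & pos,                   // flat, size n_seqs*n_seq_tokens
                      const std::vector<std::vector<int32_t>> & seq_id) { // seq_id[s]
    for (uint32_t s = 0; s < n_seqs; s++) {
        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
            const uint32_t k = s*n_seq_tokens + i;
            cells[head + k].pos = pos[k];
            for (int32_t id : seq_id[s]) {
                cells[head + k].seq_id.insert(id);
            }
        }
    }
}

int main() {
    std::vector<mini_cell> cells(8);
    std::vector<int32_t> pos = {0, 1, 0, 1};  // 2 seqs * 2 tokens
    std::vector<std::vector<int32_t>> seq_id = {{0}, {1}};
    fill_slot(cells, /*head=*/4, /*n_seqs=*/2, /*n_seq_tokens=*/2, pos, seq_id);
    // cells[4..5] now belong to seq 0 at pos 0..1, cells[6..7] to seq 1
}
```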