@@ -119,27 +119,27 @@ bool llama_kv_cache_init(
 
 struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
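The GGML_ASSERT(ubatch.equal_seqs) above encodes the precondition the recurrent path relies on: every sequence in the micro-batch contributes the same number of new tokens, so the batch factors cleanly into n_seqs runs of n_seq_tokens tokens each. A minimal standalone sketch of that invariant, using a hypothetical mock_ubatch reduced to the fields this function reads (not the real llama_ubatch definition):

#include <cassert>
#include <cstdint>

// Hypothetical mirror of the ubatch fields used above.
struct mock_ubatch {
    uint32_t n_tokens;     // total new tokens in this micro-batch
    uint32_t n_seqs;       // number of distinct sequences
    uint32_t n_seq_tokens; // new tokens per sequence
    bool     equal_seqs;   // true if every sequence has the same token count
};

// With equal_seqs set, the token count factors as n_seqs * n_seq_tokens,
// which is what lets the recurrent path claim one contiguous cell range.
static void check_equal_seqs(const mock_ubatch & ub) {
    assert(ub.equal_seqs);
    assert(ub.n_tokens == ub.n_seqs * ub.n_seq_tokens);
}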
@@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -258,20 +258,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
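The warning above fires when the incoming positions do not continue where the cell left off: after appending n_seq_tokens tokens, the last new position should sit exactly n_seq_tokens ahead of the position already stored in the cell. A small sketch of that continuity check, with hypothetical names (positions_consecutive is not a llama.cpp function):

#include <cstdint>

// Returns true when the new positions extend the cell contiguously.
// cell_pos < 0 means the cell was empty, so there is nothing to compare.
static bool positions_consecutive(int32_t cell_pos, int32_t last_pos, uint32_t n_seq_tokens) {
    if (cell_pos < 0) {
        return true;
    }
    return last_pos == cell_pos + (int32_t) n_seq_tokens;
}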
@@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
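This last hunk shows the flat layout the non-recurrent path assumes: token i of sequence s lives at index k = s*n_seq_tokens + i in ubatch.pos, and its metadata lands in cell cache.head + k. A hedged standalone sketch of that mapping, again with a hypothetical mock type rather than the real llama.cpp structs:

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical mock holding only the per-token position array.
struct mock_ubatch {
    uint32_t n_seqs;
    uint32_t n_seq_tokens;
    std::vector<int32_t> pos; // flat array of n_seqs * n_seq_tokens positions
};

// Mirrors the loop above: (s, i) maps to flat index k, and the
// destination cell is offset by the cache head.
static void print_layout(const mock_ubatch & ub, uint32_t head) {
    for (uint32_t s = 0; s < ub.n_seqs; s++) {
        for (uint32_t i = 0; i < ub.n_seq_tokens; ++i) {
            uint32_t k = s*ub.n_seq_tokens + i;
            std::printf("seq %u, token %u -> ubatch.pos[%u] = %d, cell %u\n",
                        s, i, k, ub.pos[k], head + k);
        }
    }
}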