@@ -422,9 +422,8 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
422422}
423423
424424bool llama_kv_cache_recurrent::find_slot (const llama_ubatch & ubatch) {
425- const uint32_t n_seqs = ubatch.n_seqs ;
426-
427425 const uint32_t n_seq_tokens = ubatch.n_seq_tokens ;
426+ const uint32_t n_seqs = ubatch.n_seqs ;
428427
429428 // if we have enough unused cells before the current head ->
430429 // better to start searching from the beginning of the cache, hoping to fill it
@@ -444,9 +443,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
444443
445444 // everything should fit if all seq_ids are smaller than the max
446445 for (uint32_t s = 0 ; s < n_seqs; ++s) {
447- const uint32_t n_seq_id = ubatch.n_seq_id [s*n_seq_tokens];
446+ const uint32_t i = s*n_seq_tokens; // first token of sequence set s
447+ const uint32_t n_seq_id = ubatch.n_seq_id [i];
448+
448449 for (uint32_t j = 0 ; j < n_seq_id; ++j) {
449- const llama_seq_id seq_id = ubatch.seq_id [s*n_seq_tokens ][j];
450+ const llama_seq_id seq_id = ubatch.seq_id [i ][j];
450451
451452 if (seq_id < 0 || (uint32_t ) seq_id >= size) {
452453 // too big seq_id
@@ -505,7 +506,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
505506
506507 // find usable cell range
507508 for (uint32_t s = 0 ; s < n_seqs; ++s) {
508- const llama_seq_id seq_id = ubatch.seq_id [s*n_seq_tokens][0 ];
509+ const uint32_t i = s*n_seq_tokens;
510+ const llama_seq_id seq_id = ubatch.seq_id [i][0 ];
511+
509512 kv_cell & seq_meta = cells[seq_id];
510513 bool has_cell = false ;
511514 if (seq_meta.tail >= 0 ) {
@@ -529,7 +532,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
529532 seq_meta.tail = next_empty_cell;
530533 // find next empty cell
531534 if (s + 1 < n_seqs) {
532- for (uint32_t i = 0 ; i < size; ++i ) {
535+ for (uint32_t j = 0 ; j < size; ++j ) {
533536 next_empty_cell += 1 ;
534537 if (next_empty_cell >= size) { next_empty_cell -= size; }
535538 kv_cell & cell = cells[next_empty_cell];
@@ -543,8 +546,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
543546
544547 // gather and re-order
545548 for (uint32_t s = 0 ; s < n_seqs; ++s) {
549+ const uint32_t i = s*n_seq_tokens;
546550 const int32_t dst_id = s + min;
547- const int32_t src_id = cells[ubatch.seq_id [s*n_seq_tokens ][0 ]].tail ;
551+ const int32_t src_id = cells[ubatch.seq_id [i ][0 ]].tail ;
548552 if (dst_id != src_id) {
549553 kv_cell & dst_cell = cells[dst_id];
550554 kv_cell & src_cell = cells[src_id];
@@ -554,8 +558,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
554558 std::swap (dst_cell.seq_id , src_cell.seq_id );
555559
556560 // swap tails
557- for (uint32_t i = 0 ; i < size; ++i ) {
558- int32_t & tail = cells[i ].tail ;
561+ for (uint32_t j = 0 ; j < size; ++j ) {
562+ int32_t & tail = cells[j ].tail ;
559563 if (tail == src_id) {
560564 tail = dst_id;
561565 } else if (tail == dst_id) {
@@ -567,20 +571,21 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
567571
568572 // update the pos of the used seqs
569573 for (uint32_t s = 0 ; s < n_seqs; ++s) {
570- const llama_pos last_pos = ubatch.pos [s*n_seq_tokens + n_seq_tokens - 1 ];
574+ const uint32_t i = s*n_seq_tokens;
575+ const llama_pos last_pos = ubatch.pos [i + n_seq_tokens - 1 ];
571576 const int32_t cell_id = s + min;
572577 kv_cell & cell = cells[cell_id];
573578
574579 if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
575580 // What should happen when the pos backtracks or skips a value?
576581 // Clearing the state mid-batch would require special-casing which isn't done.
577582 LLAMA_LOG_WARN (" %s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n " ,
578- __func__, last_pos, cell.pos , ubatch.seq_id [s*n_seq_tokens ][0 ], n_seq_tokens);
583+ __func__, last_pos, cell.pos , ubatch.seq_id [i ][0 ], n_seq_tokens);
579584 }
580585 cell.pos = last_pos;
581586 cell.seq_id .clear ();
582- for (int32_t j = 0 ; j < ubatch.n_seq_id [s*n_seq_tokens ]; ++j) {
583- const llama_seq_id seq_id = ubatch.seq_id [s*n_seq_tokens ][j];
587+ for (int32_t j = 0 ; j < ubatch.n_seq_id [i ]; ++j) {
588+ const llama_seq_id seq_id = ubatch.seq_id [i ][j];
584589 cell.seq_id .insert (seq_id);
585590 cells[seq_id].tail = cell_id;
586591 }
0 commit comments