@@ -406,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 
     bool success = true;
 
-    // TODO: here we have to verify that all ubatches can fit in the cells
-    // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
-    // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
-    //
-    // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
-    //
-    // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
-    //
-    GGML_UNUSED(ubatches);
-    // for (const auto & ubatch : ubatches) {
-    //     if (!find_slot(ubatch)) {
-    //         success = false;
-    //         break;
-    //     }
-    // }
+    for (const auto & ubatch : ubatches) {
+        if (!find_slot(ubatch)) {
+            success = false;
+            break;
+        }
+    }
 
     // restore the original state
     cells = std::move(org_cells);
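Note: with s_copy() and s_mask() no longer mutating cells during graph compute, prepare() can now dry-run every ubatch through find_slot() up front and then roll the metadata back. A minimal self-contained sketch of that snapshot/dry-run/restore pattern (cache_sketch, try_place, and prepare_sketch are illustrative stand-ins, not llama.cpp API; only the org_cells snapshot and the loop are visible in this hunk):

    // sketch of the pattern used by prepare(): snapshot, dry-run, restore
    #include <cstdint>
    #include <utility>
    #include <vector>

    struct cache_sketch {
        std::vector<int32_t> cells; // stand-in for the per-cell metadata
        uint32_t head = 0;
        uint32_t used = 0;
    };

    // returns true only if every batch can be placed; the cache is left untouched
    static bool prepare_sketch(cache_sketch & kv, const std::vector<int32_t> & batches,
                               bool (*try_place)(cache_sketch &, int32_t)) {
        const uint32_t org_head = kv.head;
        const uint32_t org_used = kv.used;
        auto org_cells = kv.cells;

        bool success = true;

        for (const int32_t b : batches) {
            if (!try_place(kv, b)) { // may mutate kv, like find_slot() does
                success = false;
                break;
            }
        }

        // restore the original state after the dry run
        kv.cells = std::move(org_cells);
        kv.head  = org_head;
        kv.used  = org_used;

        return success;
    }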
@@ -431,14 +422,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 }
 
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seqs = ubatch.n_seqs;
 
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     // if we have enough unused cells before the current head ->
     // better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*n_tokens) {
+    if (head > used + 2*n_seqs) {
         head = 0;
     }
 
@@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                 empty_cell.src = orig_cell.src;
                 orig_cell.seq_id.erase(seq_id);
                 empty_cell.seq_id.insert(seq_id); // will be overwritten
+                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
             }
             seq_meta.tail = next_empty_cell;
             // find next empty cell
             if (s + 1 < n_seqs) {
-                next_empty_cell += 1;
                 for (uint32_t i = 0; i < size; ++i) {
+                    next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
                     kv_cell & cell = cells[next_empty_cell];
                     if (cell.is_empty()) { break; }
-                    next_empty_cell += 1;
                 }
             }
         }
@@ -553,8 +543,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        int32_t dst_id = s + min;
-        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        const int32_t dst_id = s + min;
+        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
             kv_cell & dst_cell = cells[dst_id];
             kv_cell & src_cell = cells[src_id];
@@ -563,20 +553,22 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.src, src_cell.src);
             std::swap(dst_cell.seq_id, src_cell.seq_id);
 
-            // swap tails (assuming they NEVER overlap)
-            for (const llama_seq_id seq_id : src_cell.seq_id) {
-                cells[seq_id].tail = src_id;
-            }
-            for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                cells[seq_id].tail = dst_id;
+            // swap tails
+            for (uint32_t i = 0; i < size; ++i) {
+                int32_t & tail = cells[i].tail;
+                if (tail == src_id) {
+                    tail = dst_id;
+                } else if (tail == dst_id) {
+                    tail = src_id;
+                }
             }
         }
     }
 
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-        int32_t cell_id = s + min;
+        const int32_t cell_id = s + min;
         kv_cell & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -594,6 +586,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         }
     }
 
+    // Find first cell without src refs, to use as the zero-ed state
+    {
+        // TODO: bake-in src refcounts in the cell metadata
+        std::vector<int32_t> refcounts(size, 0);
+        for (size_t i = 0; i < size; ++i) {
+            const int32_t src = cells[i].src;
+            if (src >= 0) {
+                refcounts[src] += 1;
+            }
+        }
+
+        rs_z = -1;
+        for (int i = min; i <= max; ++i) {
+            if (refcounts[i] == 0) {
+                rs_z = i;
+                break;
+            }
+        }
+
+        for (int i = min; i <= max; ++i) {
+            if (cells[i].src < 0) {
+                GGML_ASSERT(rs_z >= 0);
+                cells[i].src0 = rs_z;
+            } else {
+                // Stage the source ids for all used cells to allow correct seq_* behavior
+                // and still make these values available when setting the inputs
+                cells[i].src0 = cells[i].src;
+            }
+            cells[i].src = i; // avoid moving or clearing twice
+        }
+    }
+
     // allow getting the range of used cells, from head to head + n
     head = min;
     n = max - min + 1;
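The block added above stages copy sources once per ubatch instead of during compute: it refcounts every cells[i].src, picks the first cell in [min, max] that nothing references as the zero-ed state rs_z, records each cell's source in src0, and finally points src back at the cell itself so later seq_* calls cannot trigger a second move or clear. A standalone sketch of that scan (kv_cell_sketch and stage_sources are illustrative names; only the src/src0 fields and the refcount idea come from the diff):

    #include <cstdint>
    #include <vector>

    struct kv_cell_sketch {
        int32_t src  = -1; // source cell of this state; < 0 means it needs the zero-ed state
        int32_t src0 = -1; // staged source id, read later when setting the graph inputs
    };

    // returns the cell chosen as the zero-ed state, or -1 if none is free
    static int32_t stage_sources(std::vector<kv_cell_sketch> & cells, int min, int max) {
        // count how many cells reference each source
        std::vector<int32_t> refcounts(cells.size(), 0);
        for (const auto & cell : cells) {
            if (cell.src >= 0) {
                refcounts[cell.src] += 1;
            }
        }

        // the first unreferenced cell in [min, max] can hold the zero-ed state
        int32_t rs_z = -1;
        for (int i = min; i <= max; ++i) {
            if (refcounts[i] == 0) {
                rs_z = i;
                break;
            }
        }

        // stage the source ids, then make each cell its own source so that
        // the copy (or clear) happens at most once
        for (int i = min; i <= max; ++i) {
            cells[i].src0 = cells[i].src < 0 ? rs_z : cells[i].src;
            cells[i].src  = i;
        }

        return rs_z;
    }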
@@ -605,47 +629,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 }
 
 bool llama_kv_cache_recurrent::get_can_shift() const {
-    return false;
-}
-
-int32_t llama_kv_cache_recurrent::s_copy(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    // prevent out-of-bound sources
-    if (cell.src < 0 || (uint32_t) cell.src >= size) {
-        cell.src = cell_id;
-    }
-
-    int32_t res = cell.src;
-
-    // TODO: do not mutate the KV cache
-    // ensure copy only happens once
-    if (cell.src != (int32_t) cell_id) {
-        cell.src = cell_id;
-    }
-
-    return res;
-}
-
-float llama_kv_cache_recurrent::s_mask(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    float res = (float) (cell.src >= 0);
-
-    // only clear once
-    if (cell.src < 0) {
-        cell.src = cell_id;
-    }
-
-    return res;
+    // shifting the pos is trivial for recurrent models
+    return true;
 }
 
 size_t llama_kv_cache_recurrent::total_size() const {
@@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
     return is_full ? 0 : kv->head;
 }
 
+int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : kv->rs_z;
+}
+
 uint32_t llama_kv_cache_recurrent_state::get_size() const {
     return kv->size;
 }
@@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
 }
 
 int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
-    return kv->s_copy(i);
-}
-
-float llama_kv_cache_recurrent_state::s_mask(int i) const {
-    return kv->s_mask(i);
+    return kv->cells[i + kv->head].src0;
 }