@@ -400,8 +400,11 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
400400 bool success = true ;
401401
402402 for (const auto & ubatch : ubatches) {
403+ // non-continuous slots require support for ggml_set_rows()
404+ const bool cont = supports_set_rows ? false : true ;
405+
403406 // only find a suitable slot for the ubatch. don't modify the cells yet
404- const auto sinfo_new = find_slot (ubatch);
407+ const auto sinfo_new = find_slot (ubatch, cont );
405408 if (sinfo_new.empty ()) {
406409 success = false ;
407410 break ;
@@ -521,7 +524,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
521524 return updated;
522525}
523526
524- llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot (const llama_ubatch & ubatch) const {
527+ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot (const llama_ubatch & ubatch, bool cont ) const {
525528 const uint32_t n_tokens = ubatch.n_tokens ;
526529
527530 uint32_t head_cur = this ->head ;
@@ -595,17 +598,25 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
595598 }
596599 }
597600
601+ uint32_t n_found = 0 ;
598602 uint32_t n_tested = 0 ;
599603
604+ const uint32_t n_test = cont ? n_tokens : 1 ;
605+
606+ slot_info res;
607+
608+ res.idxs .resize (n_tokens);
609+
600610 while (true ) {
601- if (head_cur + n_tokens > cells.size ()) {
611+ if (head_cur + n_test > cells.size ()) {
602612 n_tested += cells.size () - head_cur;
603613 head_cur = 0 ;
604614 continue ;
605615 }
606616
607- bool found = true ;
608- for (uint32_t i = 0 ; i < n_tokens; i++) {
617+ for (uint32_t i = 0 ; i < n_test; i++) {
618+ const auto idx = head_cur;
619+
609620 // const llama_pos pos = ubatch.pos[i];
610621 // const llama_seq_id seq_id = ubatch.seq_id[i][0];
611622
@@ -615,19 +626,19 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
615626 // - (disabled) mask causally, if the sequence is the same as the one we are inserting
616627 // - mask SWA, using current max pos for that sequence in the cache
617628 // always insert in the cell with minimum pos
618- bool can_use = cells.is_empty (head_cur + i );
629+ bool can_use = cells.is_empty (idx );
619630
620- if (!can_use && cells.seq_count (head_cur + i ) == 1 ) {
621- const llama_pos pos_cell = cells.pos_get (head_cur + i );
631+ if (!can_use && cells.seq_count (idx ) == 1 ) {
632+ const llama_pos pos_cell = cells.pos_get (idx );
622633
623634 // (disabled) causal mask
624635 // note: it's better to purge any "future" tokens beforehand
625- // if (cells.seq_has(head_cur + i, seq_id )) {
636+ // if (cells.seq_has(idx )) {
626637 // can_use = pos_cell >= pos;
627638 // }
628639
629640 if (!can_use) {
630- const llama_seq_id seq_id_cell = cells.seq_get (head_cur + i );
641+ const llama_seq_id seq_id_cell = cells.seq_get (idx );
631642
632643 // SWA mask
633644 if (is_masked_swa (pos_cell, cells.seq_pos_max (seq_id_cell) + 1 )) {
@@ -636,29 +647,35 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
636647 }
637648 }
638649
639- if (!can_use) {
640- found = false ;
641- head_cur += i + 1 ;
642- n_tested += i + 1 ;
650+ head_cur++;
651+ n_tested++;
652+
653+ if (can_use) {
654+ res.idxs [n_found] = idx;
655+
656+ n_found++;
657+ } else {
643658 break ;
644659 }
645660 }
646661
647- if (found ) {
662+ if (n_found == n_tokens ) {
648663 break ;
649664 }
650665
666+ if (cont) {
667+ n_found = 0 ;
668+ }
669+
651670 if (n_tested >= cells.size ()) {
652671 // LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
653672 return { };
654673 }
655674 }
656675
657- slot_info res;
658-
659- res.idxs .resize (n_tokens);
660- for (uint32_t i = 0 ; i < n_tokens; ++i) {
661- res.idxs [i] = head_cur + i;
676+ // we didn't find a suitable slot - return empty result
677+ if (n_found < n_tokens) {
678+ res.clear ();
662679 }
663680
664681 return res;
@@ -1592,7 +1609,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
15921609 ubatch.seq_id [i] = &dest_seq_id;
15931610 }
15941611
1595- const auto sinfo = find_slot (ubatch);
1612+ const auto sinfo = find_slot (ubatch, true );
15961613 if (sinfo.empty ()) {
15971614 LLAMA_LOG_ERROR (" %s: failed to find available cells in kv cache\n " , __func__);
15981615 return false ;
0 commit comments