@@ -334,13 +334,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads = prepare(ubatches);
-        if (heads.empty()) {
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
             break;
         }
 
         return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+                this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
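Note: the slot_info type itself lives in the cache header, which is not part of this diff. A minimal sketch of what its usage here implies (idxs, empty(), head(), is_cont()) follows; treat it as an inferred shape under those assumptions, not the actual declaration:

    // inferred from usage in this diff - NOT the actual llama.cpp header
    struct slot_info {
        std::vector<uint32_t> idxs; // one cache-cell index per ubatch token

        // first cell of the slot; meaningful only when !empty()
        uint32_t head() const { return idxs.at(0); }

        // whether the indices form a single contiguous run
        bool is_cont() const {
            for (size_t i = 0; i < idxs.size(); ++i) {
                if (idxs[i] != idxs[0] + i) {
                    return false;
                }
            }
            return true;
        }

        bool empty() const { return idxs.empty(); }
    };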
@@ -383,8 +383,8 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
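slot_info_vec_t is presumably a convenience alias on the class, something like the line below (an assumption; the header change is not shown in this diff):

    using slot_info_vec_t = std::vector<slot_info>; // assumed alias in llama_kv_cache_unified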
@@ -400,20 +400,25 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
 
     for (const auto & ubatch : ubatches) {
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }
 
         // remember the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);
+
+        // TODO: temporary
+        if (supports_set_rows) {
+            GGML_ASSERT(sinfo_new.is_cont());
+        }
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)});
 
         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }
 
     // iterate backwards and restore the cells to their original state
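Note on the TODO'd assert above: find_slot (below) still only ever produces contiguous runs of cells, so the supports_set_rows path can temporarily insist on is_cont(). For illustration:

    // contiguous slot, as currently produced by find_slot:
    //   sinfo.idxs = { 12, 13, 14, 15 }  ->  is_cont() == true,  head() == 12
    // scattered slot, which this assert still rules out:
    //   sinfo.idxs = { 12, 14, 19, 20 }  ->  is_cont() == false

Once all consumers index through idxs rather than head(), the assert can be dropped and find_slot can start returning non-contiguous slots.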
@@ -520,7 +525,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -533,7 +538,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return {};
     }
 
     if (debug > 0) {
@@ -649,37 +654,48 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
         if (n_tested >= cells.size()) {
             // LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return {};
         }
     }
 
-    return head_cur;
+    slot_info res;
+
+    res.idxs.resize(n_tokens);
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        res.idxs[i] = head_cur + i;
+    }
+
+    return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
     for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
+    assert(ubatch.n_tokens == sinfo.idxs.size());
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+        const auto idx = sinfo.idxs[i];
+
+        if (!cells.is_empty(idx)) {
+            assert(cells.seq_count(idx) == 1);
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos    pos    = cells.pos_get(head_cur + i);
+            const llama_seq_id seq_id = cells.seq_get(idx);
+            const llama_pos    pos    = cells.pos_get(idx);
 
             seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-            cells.rm(head_cur + i);
+            cells.rm(idx);
         }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+        cells.pos_set(idx, ubatch.pos[i]);
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+            cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
     }
 
@@ -700,7 +716,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     }
 
     // move the head at the end of the slot
-    head = head_cur + ubatch.n_tokens;
+    head = sinfo.idxs.back() + 1;
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
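For the contiguous slots that find_slot currently returns, the new head update is arithmetically identical to the old one, but unlike the old form it stays meaningful if idxs ever becomes non-contiguous, where "one past the last used cell" is the only sensible successor position:

    // idxs = { head_cur, head_cur + 1, ..., head_cur + n_tokens - 1 }
    // => sinfo.idxs.back() + 1 == head_cur + ubatch.n_tokens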
@@ -753,7 +769,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -772,12 +788,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
-            ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -814,19 +830,19 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*n_embd_v_gqa,
-                ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
     } else {
         v_cur = ggml_transpose(ctx, v_cur);
 
         v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
-                (v->ne[1]    )*ggml_element_size(v),
-                (head_cur    )*ggml_element_size(v));
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -837,7 +853,7 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;
 
     for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = head_cur + i;
+        data[i] = sinfo.idxs[i];
     }
 }
 
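set_input_kv_idxs fills the I64 index tensor that the supports_set_rows graph path consumes; that path is outside this hunk. A sketch of the intended consumption, assuming ggml_set_rows(ctx, dst, src, idxs) semantics where row i of src is written to row idxs[i] of dst (the exact reshapes in the real call sites may differ):

    // sketch only - shapes and call placement are assumptions
    // k       : [n_embd_k_gqa, kv_size]    - the cache tensor
    // k_cur   : [n_embd_k_gqa, n_tokens]   - freshly computed K rows
    // kv_idxs : [n_tokens], GGML_TYPE_I64  - filled by set_input_kv_idxs
    ggml_tensor * k_new = ggml_set_rows(ctx, k, k_cur, kv_idxs);

Because the destination rows are addressed individually, this path needs no contiguous view of the cache, which is what ultimately makes non-contiguous slot_info usable.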
@@ -1580,13 +1596,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
            ubatch.seq_id[i] = &dest_seq_id;
        }
 
-       const auto head_cur = find_slot(ubatch);
-       if (head_cur < 0) {
+       const auto sinfo = find_slot(ubatch);
+       if (sinfo.empty()) {
            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
            return false;
        }
 
-       apply_ubatch(head_cur, ubatch);
+       apply_ubatch(sinfo, ubatch);
+
+       const auto head_cur = sinfo.head();
 
        // keep the head at the old position because we will read the KV data into it in state_read_data()
        head = head_cur;
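Pinning head to sinfo.head() (rather than one past the slot) preserves the old behavior: state_read_data() streams the serialized KV data into the cells starting at that position, which again implicitly relies on the slot being a contiguous run, consistent with the temporary is_cont() assert in prepare().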
@@ -1772,7 +1790,10 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
-    head = 0;
+
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
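This constructor serves full-cache contexts that previously carried head = 0; a single dummy slot_info whose only index is cell 0 keeps sinfos[i_cur].head() == 0, so the accessors further down need no special case for this variant.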
@@ -1787,16 +1808,16 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 
 bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
         return false;
     }
 
@@ -1813,10 +1834,9 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }
 
-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
 
     n_kv = kv->get_n_kv();
-    head = heads[i_next];
 
     return true;
 }
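With apply() no longer mirroring the slot position into a context-level head member, that field can be dropped from the context entirely; every consumer below now reads the current slot via sinfos[i_cur] instead.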
@@ -1828,7 +1848,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return ubatches[i_next];
+    return ubatches[i_cur];
 }
 
 uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1844,19 +1864,19 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
 void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, head);
+    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {