77#include < cassert>
88#include < vector>
99#include < set>
10+ #include < map>
1011
1112// meta information about KV cells that can be part of multiple sequences at the same time
1213// TODO: add unit tests
@@ -164,7 +165,7 @@ class llama_kv_cells_unified {
164165 assert (seq_id >= 0 );
165166
166167 seq[i].reset (seq_id);
167- seq_pos[seq_id]. erase ( pos[i]);
168+ seq_pos_dec (seq_id, pos[i]);
168169
169170 if (seq[i].none ()) {
170171 pos[i] = -1 ;
@@ -187,7 +188,7 @@ class llama_kv_cells_unified {
187188 seq[i].reset ();
188189
189190 seq[i].set (seq_id);
190- seq_pos[seq_id]. insert ( pos[i]);
191+ seq_pos_inc (seq_id, pos[i]);
191192
192193 return false ;
193194 }
@@ -232,7 +233,7 @@ class llama_kv_cells_unified {
232233 assert (!seq[i].test (seq_id));
233234
234235 seq[i].set (seq_id);
235- seq_pos[seq_id]. insert ( pos[i]);
236+ seq_pos_inc (seq_id, pos[i]);
236237 }
237238
238239 // return the sequence id of this cell
@@ -259,7 +260,9 @@ class llama_kv_cells_unified {
259260 return -1 ;
260261 }
261262
262- return *seq_pos[seq_id].begin ();
263+ assert (seq_pos[seq_id].begin ()->second > 0 );
264+
265+ return seq_pos[seq_id].begin ()->first ;
263266 }
264267
265268 // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ class llama_kv_cells_unified {
272275 return -1 ;
273276 }
274277
275- return *seq_pos[seq_id].rbegin ();
278+ assert (seq_pos[seq_id].rbegin ()->second > 0 );
279+
280+ return seq_pos[seq_id].rbegin ()->first ;
276281 }
277282
278283 // note: call only if the cell is not empty
@@ -389,17 +394,36 @@ class llama_kv_cells_unified {
389394 // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390395 std::vector<seq_set_t > seq;
391396
392- // the set seq_pos[s] tells us which positions are currently present for sequence s
397+ // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
398+ // if the position p is not present, seq_pos[s][p] is not set
393399 // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
400+ //
401+ // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
402+ // - during performing a cache reuse via (rm + add)
403+ // - some vision models have input embeddings with repeating positions
404+ //
405+ std::map<llama_pos, int > seq_pos[LLAMA_MAX_SEQ];
395406
396407 // helper functions for updating `seq_pos`, once cell at a time:
397408
409+ void seq_pos_dec (llama_seq_id s, llama_pos p) {
410+ auto it = seq_pos[s].find (p);
411+ assert (it != seq_pos[s].end ());
412+
413+ if (--it->second == 0 ) {
414+ seq_pos[s].erase (it);
415+ }
416+ }
417+
418+ void seq_pos_inc (llama_seq_id s, llama_pos p) {
419+ seq_pos[s][p]++;
420+ }
421+
398422 // remove cell i
399423 void seq_pos_rm (uint32_t i) {
400424 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
401425 if (seq[i].test (s)) {
402- seq_pos[s]. erase ( pos[i]);
426+ seq_pos_dec (s, pos[i]);
403427 }
404428 }
405429 }
@@ -408,7 +432,7 @@ class llama_kv_cells_unified {
408432 void seq_pos_add (uint32_t i) {
409433 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
410434 if (seq[i].test (s)) {
411- seq_pos[s]. insert ( pos[i]);
435+ seq_pos_inc (s, pos[i]);
412436 }
413437 }
414438 }
0 commit comments