77#include < cassert>
88#include < vector>
99#include < set>
10- #include < map>
1110
1211// meta information about KV cells that can be part of multiple sequences at the same time
1312// TODO: add unit tests
@@ -165,7 +164,7 @@ class llama_kv_cells_unified {
165164 assert (seq_id >= 0 );
166165
167166 seq[i].reset (seq_id);
168- seq_pos_dec ( seq_id, pos[i]);
167+ seq_pos[ seq_id]. erase ( pos[i]);
169168
170169 if (seq[i].none ()) {
171170 pos[i] = -1 ;
@@ -188,7 +187,7 @@ class llama_kv_cells_unified {
188187 seq[i].reset ();
189188
190189 seq[i].set (seq_id);
191- seq_pos_inc ( seq_id, pos[i]);
190+ seq_pos[ seq_id]. insert ( pos[i]);
192191
193192 return false ;
194193 }
@@ -233,7 +232,7 @@ class llama_kv_cells_unified {
233232 assert (!seq[i].test (seq_id));
234233
235234 seq[i].set (seq_id);
236- seq_pos_inc ( seq_id, pos[i]);
235+ seq_pos[ seq_id]. insert ( pos[i]);
237236 }
238237
239238 // return the sequence id of this cell
@@ -260,9 +259,7 @@ class llama_kv_cells_unified {
260259 return -1 ;
261260 }
262261
263- assert (seq_pos[seq_id].begin ()->second > 0 );
264-
265- return seq_pos[seq_id].begin ()->first ;
262+ return *seq_pos[seq_id].begin ();
266263 }
267264
268265 // the maximum position of sequence seq_id currently present in any of the cells
@@ -275,9 +272,7 @@ class llama_kv_cells_unified {
275272 return -1 ;
276273 }
277274
278- assert (seq_pos[seq_id].rbegin ()->second > 0 );
279-
280- return seq_pos[seq_id].rbegin ()->first ;
275+ return *seq_pos[seq_id].rbegin ();
281276 }
282277
283278 // note: call only if the cell is not empty
@@ -389,41 +384,22 @@ class llama_kv_cells_unified {
389384 //
390385 std::vector<llama_pos> shift;
391386
392- using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
387+ using bits_t = std::bitset<LLAMA_MAX_SEQ>;
393388
394389 // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
395- std::vector<seq_set_t > seq;
390+ std::vector<bits_t > seq;
396391
397- // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
398- // if the position p is not present, seq_pos[s][p] is not set
392+ // the set seq_pos[s] tells us which positions are currently present for sequence s
399393 // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
400- //
401- // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
402- // - during performing a cache reuse via (rm + add)
403- // - some vision models have input embeddings with repeating positions
404- //
405- std::map<llama_pos, int > seq_pos[LLAMA_MAX_SEQ];
394+ std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
406395
407396 // helper functions for updating `seq_pos`, once cell at a time:
408397
409- void seq_pos_dec (llama_seq_id s, llama_pos p) {
410- auto it = seq_pos[s].find (p);
411- assert (it != seq_pos[s].end ());
412-
413- if (--it->second == 0 ) {
414- seq_pos[s].erase (it);
415- }
416- }
417-
418- void seq_pos_inc (llama_seq_id s, llama_pos p) {
419- seq_pos[s][p]++;
420- }
421-
422398 // remove cell i
423399 void seq_pos_rm (uint32_t i) {
424400 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
425401 if (seq[i].test (s)) {
426- seq_pos_dec (s, pos[i]);
402+ seq_pos[s]. erase ( pos[i]);
427403 }
428404 }
429405 }
@@ -432,7 +408,7 @@ class llama_kv_cells_unified {
432408 void seq_pos_add (uint32_t i) {
433409 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
434410 if (seq[i].test (s)) {
435- seq_pos_inc (s, pos[i]);
411+ seq_pos[s]. insert ( pos[i]);
436412 }
437413 }
438414 }
0 commit comments