77#include < cassert>
88#include < vector>
99#include < set>
10- #include < map>
1110
1211// meta information about KV cells that can be part of multiple sequences at the same time
1312// TODO: add unit tests
@@ -217,7 +216,7 @@ class llama_kv_cells_unified {
217216 assert (seq_id >= 0 );
218217
219218 seq[i].reset (seq_id);
220- seq_pos_dec ( seq_id, pos[i]);
219+ seq_pos[ seq_id]. erase ( pos[i]);
221220
222221 if (seq[i].none ()) {
223222 pos[i] = -1 ;
@@ -240,7 +239,7 @@ class llama_kv_cells_unified {
240239 seq[i].reset ();
241240
242241 seq[i].set (seq_id);
243- seq_pos_inc ( seq_id, pos[i]);
242+ seq_pos[ seq_id]. insert ( pos[i]);
244243
245244 return false ;
246245 }
@@ -285,7 +284,7 @@ class llama_kv_cells_unified {
285284 assert (!seq[i].test (seq_id));
286285
287286 seq[i].set (seq_id);
288- seq_pos_inc ( seq_id, pos[i]);
287+ seq_pos[ seq_id]. insert ( pos[i]);
289288 }
290289
291290 // return the sequence id of this cell
@@ -312,9 +311,7 @@ class llama_kv_cells_unified {
312311 return -1 ;
313312 }
314313
315- assert (seq_pos[seq_id].begin ()->second > 0 );
316-
317- return seq_pos[seq_id].begin ()->first ;
314+ return *seq_pos[seq_id].begin ();
318315 }
319316
320317 // the maximum position of sequence seq_id currently present in any of the cells
@@ -327,9 +324,7 @@ class llama_kv_cells_unified {
327324 return -1 ;
328325 }
329326
330- assert (seq_pos[seq_id].rbegin ()->second > 0 );
331-
332- return seq_pos[seq_id].rbegin ()->first ;
327+ return *seq_pos[seq_id].rbegin ();
333328 }
334329
335330 // note: call only if the cell is not empty
@@ -441,41 +436,22 @@ class llama_kv_cells_unified {
441436 //
442437 std::vector<llama_pos> shift;
443438
444- using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
439+ using bits_t = std::bitset<LLAMA_MAX_SEQ>;
445440
446441 // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
447- std::vector<seq_set_t > seq;
442+ std::vector<bits_t > seq;
448443
449- // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
450- // if the position p is not present, seq_pos[s][p] is not set
444+ // the set seq_pos[s] tells us which positions are currently present for sequence s
451445 // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
452- //
453- // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
454- // - during performing a cache reuse via (rm + add)
455- // - some vision models have input embeddings with repeating positions
456- //
457- std::map<llama_pos, int > seq_pos[LLAMA_MAX_SEQ];
446+ std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
458447
459448 // helper functions for updating `seq_pos`, one cell at a time:
460449
461- void seq_pos_dec (llama_seq_id s, llama_pos p) {
462- auto it = seq_pos[s].find (p);
463- assert (it != seq_pos[s].end ());
464-
465- if (--it->second == 0 ) {
466- seq_pos[s].erase (it);
467- }
468- }
469-
470- void seq_pos_inc (llama_seq_id s, llama_pos p) {
471- seq_pos[s][p]++;
472- }
473-
474450 // remove cell i
475451 void seq_pos_rm (uint32_t i) {
476452 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
477453 if (seq[i].test (s)) {
478- seq_pos_dec (s, pos[i]);
454+ seq_pos[s]. erase ( pos[i]);
479455 }
480456 }
481457 }
@@ -484,7 +460,7 @@ class llama_kv_cells_unified {
484460 void seq_pos_add (uint32_t i) {
485461 for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
486462 if (seq[i].test (s)) {
487- seq_pos_inc (s, pos[i]);
463+ seq_pos[s]. insert ( pos[i]);
488464 }
489465 }
490466 }
0 commit comments