66#include < bitset>
77#include < cassert>
88#include < vector>
9+ #include < set>
910
1011// meta information about KV cells that can be part of multiple sequences at the same time
1112// TODO: add unit tests
@@ -18,8 +19,13 @@ class llama_kv_cells_unified {
1819 seq[i].reset ();
1920 }
2021
21- used = 0 ;
2222 has_shift = false ;
23+
24+ used.clear ();
25+
26+ for (uint32_t s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
27+ seq_pos[s].clear ();
28+ }
2329 }
2430
2531 void reset_shift () {
@@ -50,7 +56,25 @@ class llama_kv_cells_unified {
5056 }
5157
5258 uint32_t get_used () const {
53- return used;
59+ return used.size ();
60+ }
61+
62+ // the index of the first cell that is used
63+ // return 0 if no cells are used
64+ uint32_t used_min () const {
65+ return used.empty () ? 0 : *used.begin ();
66+ }
67+
68+ // the index of the last cell that is used + 1
69+ // return 0 if no cells are used
70+ uint32_t used_max_p1 () const {
71+ #if 0
72+ if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
73+ if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
74+ if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
75+ #endif
76+
77+ return used.empty () ? 0 : *used.rbegin () + 1 ;
5478 }
5579
5680 bool get_has_shift () const {
@@ -69,6 +93,9 @@ class llama_kv_cells_unified {
6993 pos [isrc] = -1 ;
7094 shift[isrc] = 0 ;
7195 seq [isrc].reset ();
96+
97+ used.erase (isrc);
98+ used.insert (idst);
7299 }
73100
74101 // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
@@ -95,16 +122,24 @@ class llama_kv_cells_unified {
95122
96123 for (uint32_t j = 0 ; j < other.pos .size (); ++j) {
97124 if (pos[i + j] == -1 && other.pos [j] != -1 ) {
98- used++ ;
125+ used. insert (i + j) ;
99126 }
100127
101128 if (pos[i + j] != -1 && other.pos [j] == -1 ) {
102- used--;
129+ used.erase (i + j);
130+ }
131+
132+ if (pos[i + j] != -1 ) {
133+ seq_pos_rm (i + j);
103134 }
104135
105136 pos[i + j] = other.pos [j];
106137 seq[i + j] = other.seq [j];
107138
139+ if (pos[i + j] != -1 ) {
140+ seq_pos_add (i + j);
141+ }
142+
108143 assert (shift[i + j] == 0 );
109144 }
110145 }
@@ -118,11 +153,12 @@ class llama_kv_cells_unified {
118153 assert (seq_id >= 0 );
119154
120155 seq[i].reset (seq_id);
156+ seq_pos[seq_id].erase (pos[i]);
121157
122158 if (seq[i].none ()) {
123159 pos[i] = -1 ;
124160
125- used-- ;
161+ used. erase (i) ;
126162
127163 return true ;
128164 }
@@ -135,17 +171,22 @@ class llama_kv_cells_unified {
135171 assert (i < pos.size ());
136172
137173 if (seq[i].test (seq_id)) {
174+ seq_pos_rm (i);
138175 seq[i].reset ();
176+
139177 seq[i].set (seq_id);
178+ seq_pos[seq_id].insert (pos[i]);
140179
141180 return false ;
142181 }
143182
144183 if (seq[i].any ()) {
184+ seq_pos_rm (i);
145185 seq[i].reset ();
186+
146187 pos[i] = -1 ;
147188
148- used-- ;
189+ used. erase (i) ;
149190
150191 return true ;
151192 }
@@ -169,6 +210,33 @@ class llama_kv_cells_unified {
169210 assert (!seq[i].test (seq_id));
170211
171212 seq[i].set (seq_id);
213+ seq_pos[seq_id].insert (pos[i]);
214+ }
215+
216+ // the minimum position of sequence seq_id currently present in any of the cells
217+ // return -1 if the sequence is not present
218+ llama_pos seq_pos_min (llama_seq_id seq_id) const {
219+ assert (seq_id >= 0 );
220+ assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
221+
222+ if (seq_pos[seq_id].empty ()) {
223+ return -1 ;
224+ }
225+
226+ return *seq_pos[seq_id].begin ();
227+ }
228+
229+ // the maximum position of sequence seq_id currently present in any of the cells
230+ // return -1 if the sequence is not present
231+ llama_pos seq_pos_max (llama_seq_id seq_id) const {
232+ assert (seq_id >= 0 );
233+ assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
234+
235+ if (seq_pos[seq_id].empty ()) {
236+ return -1 ;
237+ }
238+
239+ return *seq_pos[seq_id].rbegin ();
172240 }
173241
174242 // note: call only if the cell is not empty
@@ -202,7 +270,8 @@ class llama_kv_cells_unified {
202270 assert (pos[i] == -1 );
203271
204272 pos[i] = p;
205- used++;
273+
274+ used.insert (i);
206275 }
207276
208277 // pos[i] = pos[i] + d
@@ -212,16 +281,22 @@ class llama_kv_cells_unified {
212281 assert (i < pos.size ());
213282 assert (pos[i] != -1 );
214283
284+ seq_pos_rm (i);
285+
215286 pos[i] += d;
216287 shift[i] += d;
217288
289+ seq_pos_add (i);
290+
218291 has_shift = true ;
219292
220293 if (pos[i] < 0 ) {
221- pos[i] = -1 ;
294+ seq_pos_rm (i);
295+
222296 seq[i].reset ();
297+ pos[i] = -1 ;
223298
224- used-- ;
299+ used. erase (i) ;
225300
226301 return true ;
227302 }
@@ -238,17 +313,22 @@ class llama_kv_cells_unified {
238313
239314 const llama_pos p_old = pos[i];
240315
316+ seq_pos_rm (i);
317+
241318 pos[i] /= d;
242319 shift[i] += p_old - pos[i];
243320
321+ seq_pos_add (i);
322+
244323 has_shift = true ;
245324 }
246325
247326private:
248- uint32_t used = 0 ; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
249-
250327 bool has_shift = false ;
251328
329+ // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
330+ std::set<uint32_t > used;
331+
252332 std::vector<llama_pos> pos;
253333
254334 // this array accumulates any applied shifts to the pos array since the last reset_shift() call
@@ -268,6 +348,32 @@ class llama_kv_cells_unified {
268348 //
269349 std::vector<llama_pos> shift;
270350
271- std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
272- };
351+ using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
352+
353+ // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
354+ std::vector<bits_t > seq;
355+
356+ // the set seq_pos[s] tells us which positions are currently present for sequence s
357+ // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
358+ std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
359+
360+ // helper functions for updating `seq_pos`, once cell at a time:
361+
362+ // remove cell i
363+ void seq_pos_rm (uint32_t i) {
364+ for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
365+ if (seq[i].test (s)) {
366+ seq_pos[s].erase (pos[i]);
367+ }
368+ }
369+ }
273370
371+ // add cell i
372+ void seq_pos_add (uint32_t i) {
373+ for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
374+ if (seq[i].test (s)) {
375+ seq_pos[s].insert (pos[i]);
376+ }
377+ }
378+ }
379+ };
0 commit comments