7
7
#include < cassert>
8
8
#include < vector>
9
9
#include < set>
10
+ #include < map>
10
11
11
12
// meta information about KV cells that can be part of multiple sequences at the same time
12
13
// TODO: add unit tests
@@ -164,7 +165,7 @@ class llama_kv_cells_unified {
164
165
assert (seq_id >= 0 );
165
166
166
167
seq[i].reset (seq_id);
167
- seq_pos[seq_id]. erase ( pos[i]);
168
+ seq_pos_dec (seq_id, pos[i]);
168
169
169
170
if (seq[i].none ()) {
170
171
pos[i] = -1 ;
@@ -187,7 +188,7 @@ class llama_kv_cells_unified {
187
188
seq[i].reset ();
188
189
189
190
seq[i].set (seq_id);
190
- seq_pos[seq_id]. insert ( pos[i]);
191
+ seq_pos_inc (seq_id, pos[i]);
191
192
192
193
return false ;
193
194
}
@@ -232,7 +233,7 @@ class llama_kv_cells_unified {
232
233
assert (!seq[i].test (seq_id));
233
234
234
235
seq[i].set (seq_id);
235
- seq_pos[seq_id]. insert ( pos[i]);
236
+ seq_pos_inc (seq_id, pos[i]);
236
237
}
237
238
238
239
// return the sequence id of this cell
@@ -259,7 +260,9 @@ class llama_kv_cells_unified {
259
260
return -1 ;
260
261
}
261
262
262
- return *seq_pos[seq_id].begin ();
263
+ assert (seq_pos[seq_id].begin ()->second > 0 );
264
+
265
+ return seq_pos[seq_id].begin ()->first ;
263
266
}
264
267
265
268
// the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ class llama_kv_cells_unified {
272
275
return -1 ;
273
276
}
274
277
275
- return *seq_pos[seq_id].rbegin ();
278
+ assert (seq_pos[seq_id].rbegin ()->second > 0 );
279
+
280
+ return seq_pos[seq_id].rbegin ()->first ;
276
281
}
277
282
278
283
// note: call only if the cell is not empty
@@ -389,17 +394,36 @@ class llama_kv_cells_unified {
389
394
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
395
std::vector<seq_set_t > seq;
391
396
392
- // the set seq_pos[s] tells us which positions are currently present for sequence s
397
+ // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
398
+ // if the position p is not present, seq_pos[s][p] is not set
393
399
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394
- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
400
+ //
401
+ // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
402
+ // - during performing a cache reuse via (rm + add)
403
+ // - some vision models have input embeddings with repeating positions
404
+ //
405
+ std::map<llama_pos, int > seq_pos[LLAMA_MAX_SEQ];
395
406
396
407
// helper functions for updating `seq_pos`, once cell at a time:
397
408
409
+ void seq_pos_dec (llama_seq_id s, llama_pos p) {
410
+ auto it = seq_pos[s].find (p);
411
+ assert (it != seq_pos[s].end ());
412
+
413
+ if (--it->second == 0 ) {
414
+ seq_pos[s].erase (it);
415
+ }
416
+ }
417
+
418
+ void seq_pos_inc (llama_seq_id s, llama_pos p) {
419
+ seq_pos[s][p]++;
420
+ }
421
+
398
422
// remove cell i
399
423
void seq_pos_rm (uint32_t i) {
400
424
for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
401
425
if (seq[i].test (s)) {
402
- seq_pos[s]. erase ( pos[i]);
426
+ seq_pos_dec (s, pos[i]);
403
427
}
404
428
}
405
429
}
@@ -408,7 +432,7 @@ class llama_kv_cells_unified {
408
432
void seq_pos_add (uint32_t i) {
409
433
for (int s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
410
434
if (seq[i].test (s)) {
411
- seq_pos[s]. insert ( pos[i]);
435
+ seq_pos_inc (s, pos[i]);
412
436
}
413
437
}
414
438
}
0 commit comments