@@ -23,7 +23,7 @@ class llama_kv_cells_unified {
23
23
24
24
used.clear ();
25
25
26
- for (uint32_t s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES ; ++s) {
26
+ for (uint32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
27
27
seq_pos[s].clear ();
28
28
}
29
29
}
@@ -240,7 +240,7 @@ class llama_kv_cells_unified {
240
240
llama_seq_id seq_get (uint32_t i) const {
241
241
assert (seq[i].count () == 1 );
242
242
243
- for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES ; ++s) {
243
+ for (int s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
244
244
if (seq[i].test (s)) {
245
245
return s;
246
246
}
@@ -253,7 +253,7 @@ class llama_kv_cells_unified {
253
253
// return -1 if the sequence is not present
254
254
llama_pos seq_pos_min (llama_seq_id seq_id) const {
255
255
assert (seq_id >= 0 );
256
- assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES );
256
+ assert (seq_id < LLAMA_MAX_SEQ );
257
257
258
258
if (seq_pos[seq_id].empty ()) {
259
259
return -1 ;
@@ -266,7 +266,7 @@ class llama_kv_cells_unified {
266
266
// return -1 if the sequence is not present
267
267
llama_pos seq_pos_max (llama_seq_id seq_id) const {
268
268
assert (seq_id >= 0 );
269
- assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES );
269
+ assert (seq_id < LLAMA_MAX_SEQ );
270
270
271
271
if (seq_pos[seq_id].empty ()) {
272
272
return -1 ;
@@ -384,20 +384,20 @@ class llama_kv_cells_unified {
384
384
//
385
385
std::vector<llama_pos> shift;
386
386
387
- using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES >;
387
+ using bits_t = std::bitset<LLAMA_MAX_SEQ >;
388
388
389
389
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
390
std::vector<bits_t > seq;
391
391
392
392
// the set seq_pos[s] tells us which positions are currently present for sequence s
393
393
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394
- std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES ];
394
+ std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ ];
395
395
396
396
// helper functions for updating `seq_pos`, once cell at a time:
397
397
398
398
// remove cell i
399
399
void seq_pos_rm (uint32_t i) {
400
- for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES ; ++s) {
400
+ for (int s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
401
401
if (seq[i].test (s)) {
402
402
seq_pos[s].erase (pos[i]);
403
403
}
@@ -406,7 +406,7 @@ class llama_kv_cells_unified {
406
406
407
407
// add cell i
408
408
void seq_pos_add (uint32_t i) {
409
- for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES ; ++s) {
409
+ for (int s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
410
410
if (seq[i].test (s)) {
411
411
seq_pos[s].insert (pos[i]);
412
412
}
0 commit comments