@@ -23,7 +23,7 @@ class llama_kv_cells_unified {
2323
2424 used.clear ();
2525
26- for (uint32_t s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES ; ++s) {
26+ for (uint32_t s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
2727 seq_pos[s].clear ();
2828 }
2929 }
@@ -240,7 +240,7 @@ class llama_kv_cells_unified {
240240 llama_seq_id seq_get (uint32_t i) const {
241241 assert (seq[i].count () == 1 );
242242
243- for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES ; ++s) {
243+ for (int s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
244244 if (seq[i].test (s)) {
245245 return s;
246246 }
@@ -253,7 +253,7 @@ class llama_kv_cells_unified {
253253 // return -1 if the sequence is not present
254254 llama_pos seq_pos_min (llama_seq_id seq_id) const {
255255 assert (seq_id >= 0 );
256- assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES );
256+ assert (seq_id < LLAMA_MAX_SEQ );
257257
258258 if (seq_pos[seq_id].empty ()) {
259259 return -1 ;
@@ -266,7 +266,7 @@ class llama_kv_cells_unified {
266266 // return -1 if the sequence is not present
267267 llama_pos seq_pos_max (llama_seq_id seq_id) const {
268268 assert (seq_id >= 0 );
269- assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES );
269+ assert (seq_id < LLAMA_MAX_SEQ );
270270
271271 if (seq_pos[seq_id].empty ()) {
272272 return -1 ;
@@ -384,20 +384,20 @@ class llama_kv_cells_unified {
384384 //
385385 std::vector<llama_pos> shift;
386386
387- using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES >;
387+ using bits_t = std::bitset<LLAMA_MAX_SEQ >;
388388
389389 // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390390 std::vector<bits_t > seq;
391391
392392 // the set seq_pos[s] tells us which positions are currently present for sequence s
393393 // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394- std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES ];
394+ std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ ];
395395
396396 // helper functions for updating `seq_pos`, one cell at a time:
397397
398398 // remove cell i
399399 void seq_pos_rm (uint32_t i) {
400- for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES ; ++s) {
400+ for (int s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
401401 if (seq[i].test (s)) {
402402 seq_pos[s].erase (pos[i]);
403403 }
@@ -406,7 +406,7 @@ class llama_kv_cells_unified {
406406
407407 // add cell i
408408 void seq_pos_add (uint32_t i) {
409- for (int s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES ; ++s) {
409+ for (int s = 0 ; s < LLAMA_MAX_SEQ ; ++s) {
410410 if (seq[i].test (s)) {
411411 seq_pos[s].insert (pos[i]);
412412 }
0 commit comments