
Commit 3c07909

ggerganov authored and qnixsynapse committed

cparams : rename LLAMA_MAX_PARALLEL_SEQUENCES to LLAMA_MAX_SEQ (ggml-org#14188)

ggml-ci

1 parent c94a5fb · commit 3c07909

File tree

5 files changed: 31 additions & 56 deletions

src/llama-batch.cpp
src/llama-context.cpp
src/llama-cparams.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cells.h

src/llama-batch.cpp

Lines changed: 10 additions & 10 deletions
@@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;

-    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
-    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_pos.resize(LLAMA_MAX_SEQ);
+    seq_cpl.resize(LLAMA_MAX_SEQ);
     for (auto & cur : seq_cpl) {
-        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+        cur.resize(LLAMA_MAX_SEQ);
     }
 }

@@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
                     return false;
                 }
             }
@@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
     pos.resize(batch.n_tokens);

     // initialize the starting position for each sequence based on the positions in the memory
-    llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
-    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    llama_pos p0[LLAMA_MAX_SEQ];
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (!memory) {
             p0[s] = 0;
         } else {
@@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //

-    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
     }

     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
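
For context, here is a minimal standalone sketch of the bounds check in the second hunk; the helper name validate_seq_ids and the batch layout are hypothetical, not part of the llama.cpp API. The rule it illustrates is the one above: every seq_id must lie in [0, LLAMA_MAX_SEQ).

#include <cstdint>
#include <cstdio>

#define LLAMA_MAX_SEQ 64

typedef int32_t llama_seq_id; // matches the llama.h typedef

// hypothetical helper mirroring the check in llama_batch_allocr::init():
// a sequence id is valid only if it lies in [0, LLAMA_MAX_SEQ)
static bool validate_seq_ids(const llama_seq_id * seq_id, int32_t n_seq_id) {
    for (int32_t s = 0; s < n_seq_id; ++s) {
        if (seq_id[s] < 0 || seq_id[s] >= LLAMA_MAX_SEQ) {
            fprintf(stderr, "invalid seq_id = %d (must be in [0, %d))\n",
                    seq_id[s], LLAMA_MAX_SEQ);
            return false;
        }
    }
    return true;
}

int main() {
    const llama_seq_id ok[]  = {0, 5};
    const llama_seq_id bad[] = {LLAMA_MAX_SEQ}; // one past the last valid id
    printf("ok:  %d\n", validate_seq_ids(ok,  2)); // prints 1
    printf("bad: %d\n", validate_seq_ids(bad, 1)); // prints 0
    return 0;
}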

src/llama-context.cpp

Lines changed: 5 additions & 5 deletions
@@ -29,8 +29,8 @@ llama_context::llama_context(
     const auto & hparams = model.hparams;

     cparams.n_seq_max = std::max(1u, params.n_seq_max);
-    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
-        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }

     cparams.n_threads = params.n_threads;
@@ -1023,8 +1023,8 @@ int llama_context::decode(const llama_batch & batch_inp) {

     if (!res) {
         // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-        llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        llama_pos pos_min[LLAMA_MAX_SEQ];
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             pos_min[s] = std::numeric_limits<llama_pos>::max();
         }

@@ -1035,7 +1035,7 @@
             pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
         }

-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
                 continue;
             }
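
The decode() hunks follow a pattern worth spelling out. Below is a standalone sketch (the (seq_id, pos) pairs are made up) of the same bookkeeping: seed each per-sequence slot with std::numeric_limits<llama_pos>::max() as a sentinel, fold in the minimum position seen per sequence, then skip any sequence whose slot still holds the sentinel.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <utility>

#define LLAMA_MAX_SEQ 64

typedef int32_t llama_pos;    // matches the llama.h typedefs
typedef int32_t llama_seq_id;

int main() {
    // sentinel: "no position recorded for this sequence yet"
    llama_pos pos_min[LLAMA_MAX_SEQ];
    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
        pos_min[s] = std::numeric_limits<llama_pos>::max();
    }

    // made-up (seq_id, pos) pairs standing in for the failed ubatch
    const std::pair<llama_seq_id, llama_pos> tokens[] = {{0, 12}, {0, 13}, {2, 40}};
    for (const auto & t : tokens) {
        pos_min[t.first] = std::min(pos_min[t.first], t.second);
    }

    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
            continue; // this sequence was not part of the ubatch
        }
        // the real code removes these positions from the KV cache here
        printf("seq %d: min pos %d\n", s, pos_min[s]);
    }
    return 0;
}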

src/llama-cparams.h

Lines changed: 1 addition & 2 deletions
@@ -4,8 +4,7 @@

 #include <cstdint>

-// TODO: rename to something shorter
-#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+#define LLAMA_MAX_SEQ 64

 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference
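
The macro keeps its value (64); only the name changes. Because it is a compile-time constant it can size stack arrays and std::bitset template parameters, which is how the other files in this commit use it. A standalone illustration (not the llama.cpp header itself):

#include <bitset>
#include <cstdint>

#define LLAMA_MAX_SEQ 64

typedef int32_t llama_pos;

int main() {
    llama_pos p0[LLAMA_MAX_SEQ] = {};    // fixed-size scratch, one slot per sequence
    std::bitset<LLAMA_MAX_SEQ> occupied; // one bit per sequence, as in llama-kv-cells.h
    p0[0] = 7;
    occupied.set(0);
    return (p0[0] == 7 && occupied.test(0)) ? 0 : 1;
}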

src/llama-kv-cache-unified.cpp

Lines changed: 4 additions & 4 deletions
@@ -572,7 +572,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
     }

-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (cells.seq_pos_min(s) < 0) {
             continue;
         }
@@ -652,8 +652,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch

     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
-    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }

@@ -684,7 +684,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
     // will be present in the cache. so we have to purge any position which is less than those we would overwrite
     // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }
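
To make the purge rule in the last hunk concrete, here is a hedged standalone sketch; the container and the values are illustrative, not the cache implementation. If the ubatch overwrites a sequence's cells up to position seq_pos_max_rm, every cached position at or below it is dropped, so no gap can appear in the sequence's [pos_min, pos_max] range.

#include <cstdint>
#include <cstdio>
#include <set>

typedef int32_t llama_pos;

int main() {
    // cached positions of one sequence before the ubatch is applied
    std::set<llama_pos> cached = {0, 1, 2, 3, 4};

    // the ubatch reused cells holding positions 2 and 3 of this sequence
    llama_pos seq_pos_max_rm = 3;

    // purge the whole prefix [0, seq_pos_max_rm]: keeping positions 0 and 1
    // while 2 and 3 are overwritten would leave a hole below position 4
    cached.erase(cached.begin(), cached.upper_bound(seq_pos_max_rm));

    for (llama_pos p : cached) {
        printf("%d ", p); // prints: 4
    }
    printf("\n");
    return 0;
}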

src/llama-kv-cells.h

Lines changed: 11 additions & 35 deletions
@@ -7,7 +7,6 @@
 #include <cassert>
 #include <vector>
 #include <set>
-#include <map>

 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -165,7 +164,7 @@ class llama_kv_cells_unified {
         assert(seq_id >= 0);

         seq[i].reset(seq_id);
-        seq_pos_dec(seq_id, pos[i]);
+        seq_pos[seq_id].erase(pos[i]);

         if (seq[i].none()) {
             pos[i] = -1;
@@ -188,7 +187,7 @@
             seq[i].reset();

             seq[i].set(seq_id);
-            seq_pos_inc(seq_id, pos[i]);
+            seq_pos[seq_id].insert(pos[i]);

             return false;
         }
@@ -233,7 +232,7 @@
         assert(!seq[i].test(seq_id));

         seq[i].set(seq_id);
-        seq_pos_inc(seq_id, pos[i]);
+        seq_pos[seq_id].insert(pos[i]);
     }

     // return the sequence id of this cell
@@ -260,9 +259,7 @@
             return -1;
         }

-        assert(seq_pos[seq_id].begin()->second > 0);
-
-        return seq_pos[seq_id].begin()->first;
+        return *seq_pos[seq_id].begin();
     }

     // the maximum position of sequence seq_id currently present in any of the cells
@@ -275,9 +272,7 @@
             return -1;
         }

-        assert(seq_pos[seq_id].rbegin()->second > 0);
-
-        return seq_pos[seq_id].rbegin()->first;
+        return *seq_pos[seq_id].rbegin();
     }

     // note: call only if the cell is not empty
@@ -389,41 +384,22 @@
     //
     std::vector<llama_pos> shift;

-    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+    using bits_t = std::bitset<LLAMA_MAX_SEQ>;

     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
-    std::vector<seq_set_t> seq;
+    std::vector<bits_t> seq;

-    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
-    // if the position p is not present, seq_pos[s][p] is not set
+    // the set seq_pos[s] tells us which positions are currently present for sequence s
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    //
-    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
-    //   - during performing a cache reuse via (rm + add)
-    //   - some vision models have input embeddings with repeating positions
-    //
-    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
+    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];

     // helper functions for updating `seq_pos`, once cell at a time:

-    void seq_pos_dec(llama_seq_id s, llama_pos p) {
-        auto it = seq_pos[s].find(p);
-        assert(it != seq_pos[s].end());
-
-        if (--it->second == 0) {
-            seq_pos[s].erase(it);
-        }
-    }
-
-    void seq_pos_inc(llama_seq_id s, llama_pos p) {
-        seq_pos[s][p]++;
-    }
-
     // remove cell i
     void seq_pos_rm(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos_dec(s, pos[i]);
+                seq_pos[s].erase(pos[i]);
             }
         }
     }
@@ -432,7 +408,7 @@
     void seq_pos_add(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos_inc(s, pos[i]);
+                seq_pos[s].insert(pos[i]);
             }
         }
     }
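
The data-structure change above is the substantive part of this file: seq_pos[s] becomes a plain std::set<llama_pos>, so a sequence's minimum and maximum positions fall out of the ordering as *begin() and *rbegin(). A standalone sketch with hypothetical names (seq_positions is not a llama.cpp type):

#include <cstdint>
#include <cstdio>
#include <set>

typedef int32_t llama_pos;

struct seq_positions {
    std::set<llama_pos> pos; // ordered, so begin()/rbegin() are min/max

    void add(llama_pos p)    { pos.insert(p); }
    void remove(llama_pos p) { pos.erase(p);  }

    // -1 mirrors the "sequence not present" convention of seq_pos_min/max
    llama_pos min() const { return pos.empty() ? -1 : *pos.begin();  }
    llama_pos max() const { return pos.empty() ? -1 : *pos.rbegin(); }
};

int main() {
    seq_positions s;
    s.add(5); s.add(7); s.add(6);
    printf("min=%d max=%d\n", s.min(), s.max()); // min=5 max=7
    s.remove(5);
    printf("min=%d max=%d\n", s.min(), s.max()); // min=6 max=7
    return 0;
}

Note the trade-off recorded in the removed comment: a std::set cannot count a position held by more than one cell of the same sequence (cache reuse via rm + add, vision models with repeating positions), which is exactly what the removed std::map<llama_pos, int> tracked with its per-position counter.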
