
Commit 335d1fd

ggerganov and qnixsynapse (authored and committed)

cparams : rename LLAMA_MAX_PARALLEL_SEQUENCES to LLAMA_MAX_SEQ (ggml-org#14188)

ggml-ci

1 parent c4df4a7 · commit 335d1fd

File tree: 5 files changed (+31, -56 lines)

src/llama-batch.cpp (10 additions, 10 deletions)

```diff
@@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
 
-    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
-    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_pos.resize(LLAMA_MAX_SEQ);
+    seq_cpl.resize(LLAMA_MAX_SEQ);
     for (auto & cur : seq_cpl) {
-        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+        cur.resize(LLAMA_MAX_SEQ);
     }
 }
 
@@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
                     return false;
                 }
             }
@@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
     pos.resize(batch.n_tokens);
 
     // initialize the starting position for each sequence based on the positions in the memory
-    llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
-    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    llama_pos p0[LLAMA_MAX_SEQ];
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (!memory) {
             p0[s] = 0;
         } else {
@@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
```
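The bounds check in the second hunk is the key guard here: every user-supplied sequence id must fall in [0, LLAMA_MAX_SEQ) before it is used to index the fixed-size per-sequence arrays. A minimal sketch of the same pattern with a simplified stand-in struct (`batch_view` and `validate_seq_ids` are illustrative names, not llama.cpp API):

```cpp
#include <cstdint>
#include <cstdio>

#define LLAMA_MAX_SEQ 64

// simplified stand-in for the relevant llama_batch fields (illustration only)
struct batch_view {
    int32_t        n_tokens;
    int32_t      * n_seq_id; // number of sequence ids attached to each token
    int32_t     ** seq_id;   // seq_id[i][s] = s-th sequence id of token i
};

// reject any sequence id outside [0, LLAMA_MAX_SEQ) before it is used to
// index per-sequence arrays of size LLAMA_MAX_SEQ
static bool validate_seq_ids(const batch_view & batch) {
    if (!batch.seq_id) {
        return true;
    }
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
            if (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ) {
                fprintf(stderr, "invalid seq_id[%d][%d] = %d (max %d)\n",
                        i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ - 1);
                return false;
            }
        }
    }
    return true;
}

int main() {
    int32_t   ids0[]     = { 0, 65 };  // 65 is out of range
    int32_t * seq_ids[]  = { ids0 };
    int32_t   n_seq_id[] = { 2 };

    batch_view batch = { /*n_tokens=*/1, n_seq_id, seq_ids };
    return validate_seq_ids(batch) ? 0 : 1; // exits 1: the batch is rejected
}
```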

src/llama-context.cpp (5 additions, 5 deletions)

```diff
@@ -29,8 +29,8 @@ llama_context::llama_context(
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
-    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
-        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }
 
     cparams.n_threads = params.n_threads;
@@ -1023,8 +1023,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     if (!res) {
         // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-        llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        llama_pos pos_min[LLAMA_MAX_SEQ];
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            pos_min[s] = std::numeric_limits<llama_pos>::max();
        }
 
@@ -1035,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
             pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
         }
 
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
                 continue;
             }
```
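The constructor hunk keeps the existing two-step guard and only renames the bound: clamp the requested `n_seq_max` up to at least 1, then throw if it exceeds the compile-time ceiling. A self-contained sketch of that logic (`checked_n_seq_max` is a hypothetical helper, not part of llama.cpp):

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <string>

#define LLAMA_MAX_SEQ 64

// the same two-step guard in isolation: raise the request to at least 1,
// then reject anything above the compile-time ceiling
static uint32_t checked_n_seq_max(uint32_t requested) {
    const uint32_t n_seq_max = std::max(1u, requested);
    if (n_seq_max > LLAMA_MAX_SEQ) {
        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
    }
    return n_seq_max;
}
```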

src/llama-cparams.h (1 addition, 2 deletions)

```diff
@@ -4,8 +4,7 @@
 
 #include <cstdint>
 
-// TODO: rename to something shorter
-#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+#define LLAMA_MAX_SEQ 64
 
 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference
```
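This hunk is the heart of the commit: the shorter name resolves the old TODO, and the limit stays a preprocessor constant rather than a runtime `llama_cparams` field because call sites elsewhere in the diff size stack arrays and `std::bitset` instantiations with it, both of which need a compile-time bound. A tiny illustration reusing names that appear in the other files of this diff:

```cpp
#include <bitset>
#include <cstdint>

#define LLAMA_MAX_SEQ 64

using llama_pos = int32_t;

// both usages below need the bound at compile time, which is why the limit
// is a macro and not a runtime configuration value:
using bits_t = std::bitset<LLAMA_MAX_SEQ>; // per-cell sequence occupancy mask
static llama_pos p0[LLAMA_MAX_SEQ];        // fixed-size per-sequence scratch array
```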

src/llama-kv-cache-unified.cpp (4 additions, 4 deletions)

```diff
@@ -572,7 +572,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
     }
 
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (cells.seq_pos_min(s) < 0) {
             continue;
         }
@@ -652,8 +652,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
-    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
@@ -684,7 +684,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
     // will be present in the cache. so we have to purge any position which is less than those we would overwrite
     // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }
```
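The comment in the last hunk states the invariant this code protects: for each sequence, every position in [pos_min, pos_max] must be present in the cache, so after overwriting cells the code purges any cached position at or below the highest position it overwrote. A hedged sketch of that purge step over plain `std::set` containers (`purge_overwritten` and `cache_pos` are illustrative, not the actual `llama_kv_cache_unified` internals):

```cpp
#include <cstdint>
#include <set>

#define LLAMA_MAX_SEQ 64

using llama_pos = int32_t;

// hypothetical per-sequence position sets standing in for the cache cells
static std::set<llama_pos> cache_pos[LLAMA_MAX_SEQ];

// after applying a ubatch we may have overwritten cells holding positions up
// to seq_pos_max_rm[s]; to keep [pos_min, pos_max] contiguous per sequence,
// drop every cached position at or below that mark
static void purge_overwritten(const llama_pos seq_pos_max_rm[LLAMA_MAX_SEQ]) {
    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue; // nothing was overwritten for this sequence
        }
        cache_pos[s].erase(cache_pos[s].begin(),
                           cache_pos[s].upper_bound(seq_pos_max_rm[s]));
    }
}

int main() {
    cache_pos[0] = {0, 1, 2, 3, 4, 5};

    llama_pos rm[LLAMA_MAX_SEQ];
    for (auto & v : rm) v = -1;
    rm[0] = 2;             // pretend positions up to 2 were overwritten

    purge_overwritten(rm); // sequence 0 now holds {3, 4, 5}
    return cache_pos[0].size() == 3 ? 0 : 1;
}
```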

src/llama-kv-cells.h (11 additions, 35 deletions)

```diff
@@ -7,7 +7,6 @@
 #include <cassert>
 #include <vector>
 #include <set>
-#include <map>
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -217,7 +216,7 @@ class llama_kv_cells_unified {
         assert(seq_id >= 0);
 
         seq[i].reset(seq_id);
-        seq_pos_dec(seq_id, pos[i]);
+        seq_pos[seq_id].erase(pos[i]);
 
         if (seq[i].none()) {
             pos[i] = -1;
@@ -240,7 +239,7 @@ class llama_kv_cells_unified {
         seq[i].reset();
 
         seq[i].set(seq_id);
-        seq_pos_inc(seq_id, pos[i]);
+        seq_pos[seq_id].insert(pos[i]);
 
         return false;
     }
@@ -285,7 +284,7 @@ class llama_kv_cells_unified {
         assert(!seq[i].test(seq_id));
 
         seq[i].set(seq_id);
-        seq_pos_inc(seq_id, pos[i]);
+        seq_pos[seq_id].insert(pos[i]);
     }
 
     // return the sequence id of this cell
@@ -312,9 +311,7 @@ class llama_kv_cells_unified {
             return -1;
         }
 
-        assert(seq_pos[seq_id].begin()->second > 0);
-
-        return seq_pos[seq_id].begin()->first;
+        return *seq_pos[seq_id].begin();
     }
 
     // the maximum position of sequence seq_id currently present in any of the cells
@@ -327,9 +324,7 @@ class llama_kv_cells_unified {
             return -1;
         }
 
-        assert(seq_pos[seq_id].rbegin()->second > 0);
-
-        return seq_pos[seq_id].rbegin()->first;
+        return *seq_pos[seq_id].rbegin();
     }
 
     // note: call only if the cell is not empty
@@ -441,41 +436,22 @@ class llama_kv_cells_unified {
     //
     std::vector<llama_pos> shift;
 
-    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+    using bits_t = std::bitset<LLAMA_MAX_SEQ>;
 
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
-    std::vector<seq_set_t> seq;
+    std::vector<bits_t> seq;
 
-    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
-    // if the position p is not present, seq_pos[s][p] is not set
+    // the set seq_pos[s] tells us which positions are currently present for sequence s
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    //
-    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
-    //   - during performing a cache reuse via (rm + add)
-    //   - some vision models have input embeddings with repeating positions
-    //
-    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
+    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
 
     // helper functions for updating `seq_pos`, once cell at a time:
 
-    void seq_pos_dec(llama_seq_id s, llama_pos p) {
-        auto it = seq_pos[s].find(p);
-        assert(it != seq_pos[s].end());
-
-        if (--it->second == 0) {
-            seq_pos[s].erase(it);
-        }
-    }
-
-    void seq_pos_inc(llama_seq_id s, llama_pos p) {
-        seq_pos[s][p]++;
-    }
-
     // remove cell i
     void seq_pos_rm(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos_dec(s, pos[i]);
+                seq_pos[s].erase(pos[i]);
             }
         }
     }
@@ -484,7 +460,7 @@ class llama_kv_cells_unified {
     void seq_pos_add(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos_inc(s, pos[i]);
+                seq_pos[s].insert(pos[i]);
             }
         }
     }
```
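The data-structure change leans on one property of `std::set<llama_pos>`: it keeps its elements ordered, so `*begin()` and `*rbegin()` are the minimum and maximum cached positions, exactly as the `seq_pos` comment says. A small standalone demonstration of that trick:

```cpp
#include <cassert>
#include <cstdint>
#include <set>

using llama_pos = int32_t;

int main() {
    // an ordered set keeps its elements sorted, so the smallest and largest
    // cached positions for a sequence are simply the first and last elements
    std::set<llama_pos> seq_pos = {7, 3, 12, 5};

    const llama_pos pos_min = seq_pos.empty() ? -1 : *seq_pos.begin();
    const llama_pos pos_max = seq_pos.empty() ? -1 : *seq_pos.rbegin();

    assert(pos_min == 3);
    assert(pos_max == 12);

    // inserting an already-present position is a no-op; the map on the minus
    // side of this diff counted such duplicates instead
    seq_pos.insert(12);
    assert(seq_pos.size() == 4);

    return 0;
}
```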
