Skip to content

Commit d406c39

Browse files
ggerganovMinh141120
authored andcommitted
cparams : rename LLAMA_MAX_PARALLEL_SEQUENCES to LLAMA_MAX_SEQ (ggml-org#14188)
ggml-ci
1 parent 4460d36 commit d406c39

File tree

4 files changed

+18
-23
lines changed

4 files changed

+18
-23
lines changed

src/llama-batch.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -627,10 +627,10 @@ llama_batch_allocr::llama_batch_allocr() {
627627
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
628628
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
629629

630-
seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
631-
seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
630+
seq_pos.resize(LLAMA_MAX_SEQ);
631+
seq_cpl.resize(LLAMA_MAX_SEQ);
632632
for (auto & cur : seq_cpl) {
633-
cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
633+
cur.resize(LLAMA_MAX_SEQ);
634634
}
635635
}
636636

@@ -660,8 +660,8 @@ bool llama_batch_allocr::init(
660660
if (batch.seq_id) {
661661
for (int32_t i = 0; i < batch.n_tokens; ++i) {
662662
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
663-
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
664-
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
663+
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
664+
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
665665
return false;
666666
}
667667
}
@@ -693,8 +693,8 @@ bool llama_batch_allocr::init(
693693
pos.resize(batch.n_tokens);
694694

695695
// initialize the starting position for each sequence based on the positions in the memory
696-
llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
697-
for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
696+
llama_pos p0[LLAMA_MAX_SEQ];
697+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
698698
if (!memory) {
699699
p0[s] = 0;
700700
} else {
@@ -818,7 +818,7 @@ bool llama_batch_allocr::init(
818818
// consistency checks
819819
//
820820

821-
for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
821+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
822822
if (seq_pos[s].empty()) {
823823
continue;
824824
}
@@ -835,8 +835,8 @@ bool llama_batch_allocr::init(
835835
}
836836

837837
if (memory) {
838-
for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
839-
for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
838+
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
839+
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
840840
if (seq_cpl[s0][s1]) {
841841
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
842842
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {

src/llama-cparams.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
#include <cstdint>
66

7-
// TODO: rename to something shorter
8-
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
7+
#define LLAMA_MAX_SEQ 64
98

109
struct llama_cparams {
1110
uint32_t n_ctx; // context size used during inference

src/llama-kv-cache-unified.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
603603
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
604604
}
605605

606-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
606+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
607607
if (cells.seq_pos_min(s) < 0) {
608608
continue;
609609
}
@@ -686,8 +686,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
686686

687687
// keep track of the max sequence position that we would overwrite with this ubatch
688688
// for non-SWA cache, this would be always empty
689-
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
690-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
689+
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
690+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
691691
seq_pos_max_rm[s] = -1;
692692
}
693693

@@ -718,7 +718,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
718718
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
719719
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
720720
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
721-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
721+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
722722
if (seq_pos_max_rm[s] == -1) {
723723
continue;
724724
}

src/llama-kv-cells.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class llama_kv_cells_unified {
2424

2525
used.clear();
2626

27+
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
2728
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
2829
seq_pos[s].clear();
2930
}
@@ -389,20 +390,15 @@ class llama_kv_cells_unified {
389390
//
390391
std::vector<llama_pos> shift;
391392

392-
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
393+
using bits_t = std::bitset<LLAMA_MAX_SEQ>;
393394

394395
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
395396
std::vector<seq_set_t> seq;
396397

397398
// the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
398399
// if the position p is not present, seq_pos[s][p] is not set
399400
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
400-
//
401-
// note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
402-
// - during performing a cache reuse via (rm + add)
403-
// - some vision models have input embeddings with repeating positions
404-
//
405-
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
401+
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
406402

407403
// helper functions for updating `seq_pos`, once cell at a time:
408404

0 commit comments

Comments
 (0)