Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/llama-batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;

seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
seq_pos.resize(LLAMA_MAX_SEQ);
seq_cpl.resize(LLAMA_MAX_SEQ);
for (auto & cur : seq_cpl) {
cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
cur.resize(LLAMA_MAX_SEQ);
}
}

Expand Down Expand Up @@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
if (batch.seq_id) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
return false;
}
}
Expand Down Expand Up @@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
pos.resize(batch.n_tokens);

// initialize the starting position for each sequence based on the positions in the memory
llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
llama_pos p0[LLAMA_MAX_SEQ];
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (!memory) {
p0[s] = 0;
} else {
Expand Down Expand Up @@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
// consistency checks
//

for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq_pos[s].empty()) {
continue;
}
Expand All @@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
}

if (memory) {
for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
if (seq_cpl[s0][s1]) {
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
Expand Down
10 changes: 5 additions & 5 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ llama_context::llama_context(
const auto & hparams = model.hparams;

cparams.n_seq_max = std::max(1u, params.n_seq_max);
if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
}

cparams.n_threads = params.n_threads;
Expand Down Expand Up @@ -1023,8 +1023,8 @@ int llama_context::decode(const llama_batch & batch_inp) {

if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
llama_pos pos_min[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
}

Expand All @@ -1035,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}

for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion src/llama-cparams.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "llama-cparams.h"

size_t llama_max_parallel_sequences(void) {
return LLAMA_MAX_PARALLEL_SEQUENCES;
return LLAMA_MAX_SEQ;
}
3 changes: 1 addition & 2 deletions src/llama-cparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

#include <cstdint>

// TODO: rename to something shorter
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
#define LLAMA_MAX_SEQ 64

struct llama_cparams {
uint32_t n_ctx; // context size used during inference
Expand Down
8 changes: 4 additions & 4 deletions src/llama-kv-cache-unified.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
}

for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (cells.seq_pos_min(s) < 0) {
continue;
}
Expand Down Expand Up @@ -652,8 +652,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch

// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
seq_pos_max_rm[s] = -1;
}

Expand Down Expand Up @@ -684,7 +684,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq_pos_max_rm[s] == -1) {
continue;
}
Expand Down
16 changes: 8 additions & 8 deletions src/llama-kv-cells.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class llama_kv_cells_unified {

used.clear();

for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
seq_pos[s].clear();
}
}
Expand Down Expand Up @@ -240,7 +240,7 @@ class llama_kv_cells_unified {
llama_seq_id seq_get(uint32_t i) const {
assert(seq[i].count() == 1);

for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
return s;
}
Expand All @@ -253,7 +253,7 @@ class llama_kv_cells_unified {
// return -1 if the sequence is not present
llama_pos seq_pos_min(llama_seq_id seq_id) const {
assert(seq_id >= 0);
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
assert(seq_id < LLAMA_MAX_SEQ);

if (seq_pos[seq_id].empty()) {
return -1;
Expand All @@ -266,7 +266,7 @@ class llama_kv_cells_unified {
// return -1 if the sequence is not present
llama_pos seq_pos_max(llama_seq_id seq_id) const {
assert(seq_id >= 0);
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
assert(seq_id < LLAMA_MAX_SEQ);

if (seq_pos[seq_id].empty()) {
return -1;
Expand Down Expand Up @@ -384,20 +384,20 @@ class llama_kv_cells_unified {
//
std::vector<llama_pos> shift;

using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
using bits_t = std::bitset<LLAMA_MAX_SEQ>;

// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
std::vector<bits_t> seq;

// the set seq_pos[s] tells us which positions are currently present for sequence s
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];

// helper functions for updating `seq_pos`, once cell at a time:

// remove cell i
void seq_pos_rm(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos[s].erase(pos[i]);
}
Expand All @@ -406,7 +406,7 @@ class llama_kv_cells_unified {

// add cell i
void seq_pos_add(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos[s].insert(pos[i]);
}
Expand Down
Loading