Skip to content

Commit 1c25994

Browse files
committed
feat: Add can_seq_rm API to llama_kv_cache API
This will be key for the hybrid cache, which needs to validate that all children can perform seq_rm cleanly before attempting to remove the seq from any single child, so that it avoids ending up in a corrupted state. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 4b45d1a commit 1c25994

File tree

2 files changed

+65
-15
lines changed

2 files changed

+65
-15
lines changed

src/llama-kv-cache.cpp

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,14 @@ void llama_kv_cache_unified::set_full() {
464464
head = 0;
465465
}
466466

467+
// Reports whether the requested positions of seq_id can be removed from this
// cache. The unified attention cache can always do a sequence removal, so the
// result does not depend on any of the arguments.
bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    // the arguments are intentionally ignored
    GGML_UNUSED(p1);
    GGML_UNUSED(p0);
    GGML_UNUSED(seq_id);

    return true;
}
474+
467475
// Builds the llama_sbatch for `batch` from the batch itself, hparams.n_embd
// and the caller's `logits_all` flag.
llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
    const auto n_embd = hparams.n_embd;

    // NOTE(review): the third argument is a fixed `true`; its meaning is defined
    // by the llama_sbatch constructor, which is not visible here
    return llama_sbatch(batch, n_embd, true, logits_all);
}
@@ -1747,6 +1755,15 @@ void llama_kv_cache_unified_iswa::set_full() {
17471755
kv_swa ->set_full();
17481756
}
17491757

1758+
// Reports whether the requested positions of seq_id can be removed from this
// cache. Both children are unified attention caches, which can always do a
// sequence removal, so the parent can as well and the arguments are ignored.
bool llama_kv_cache_unified_iswa::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    // the arguments are intentionally ignored
    GGML_UNUSED(p1);
    GGML_UNUSED(p0);
    GGML_UNUSED(seq_id);

    return true;
}
1766+
17501767
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
17511768
pending.clear();
17521769

@@ -1928,39 +1945,33 @@ void llama_kv_cache_recurrent::clear() {
19281945
}
19291946

19301947
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1931-
uint32_t new_head = size;
1948+
if (!can_seq_rm(seq_id, p0, p1)) {
1949+
// could be fatal
1950+
return false;
1951+
}
19321952

1953+
uint32_t new_head = size;
19331954
if (p0 < 0) {
19341955
p0 = 0;
19351956
}
1936-
19371957
if (p1 < 0) {
19381958
p1 = std::numeric_limits<llama_pos>::max();
19391959
}
19401960

1941-
// models like Mamba or RWKV can't have a state partially erased
1942-
if (seq_id >= (int64_t) size) {
1943-
// could be fatal
1944-
return false;
1945-
}
19461961
if (0 <= seq_id) {
19471962
int32_t & tail_id = cells[seq_id].tail;
19481963
if (tail_id >= 0) {
19491964
const kv_cell & cell = cells[tail_id];
1950-
// partial intersection is invalid
1951-
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1952-
return false;
1953-
}
1965+
// already validated in can_seq_rm
1966+
GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)));
19541967
// invalidate tails which will be cleared
19551968
if (p0 <= cell.pos && cell.pos < p1) {
19561969
tail_id = -1;
19571970
}
19581971
}
19591972
} else {
1960-
// seq_id is negative, then the range should include everything or nothing
1961-
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1962-
return false;
1963-
}
1973+
// already validated in can_seq_rm
1974+
GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())));
19641975
}
19651976

19661977
for (uint32_t i = 0; i < size; ++i) {
@@ -2177,6 +2188,34 @@ void llama_kv_cache_recurrent::set_full() {
21772188
n = size;
21782189
head = 0;
21792190
}
2191+
// Checks, without modifying the cache, whether seq_rm(seq_id, p0, p1) would be
// accepted by the recurrent cache.
//
// A negative p0 is treated as 0 and a negative p1 as the maximum llama_pos.
// Returns false when:
//   - seq_id is not below the cache size
//   - seq_id is non-negative and p0 or p1 falls (strictly above 0) at or
//     before the position of that sequence's tail cell, i.e. the range only
//     partially intersects the state (models like Mamba or RWKV can't have a
//     state partially erased)
//   - seq_id is negative and the range is neither empty nor everything
// Returns true otherwise.
bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    const llama_pos pos_max = std::numeric_limits<llama_pos>::max();

    // normalize open-ended bounds
    const llama_pos lo = p0 < 0 ? 0 : p0;
    const llama_pos hi = p1 < 0 ? pos_max : p1;

    if (seq_id >= (int64_t) size) {
        // could be fatal
        return false;
    }

    if (seq_id < 0) {
        // seq_id is negative, then the range should include everything or nothing
        const bool covers_nothing    = lo == hi;
        const bool covers_everything = lo == 0 && hi == pos_max;
        return covers_nothing || covers_everything;
    }

    const int32_t tail_id = cells[seq_id].tail;
    if (tail_id < 0) {
        // the sequence has no tail set, so there is nothing to intersect with
        return true;
    }

    // partial intersection is invalid
    const kv_cell & tail = cells[tail_id];
    const bool lo_cuts_state = 0 < lo && lo <= tail.pos;
    const bool hi_cuts_state = 0 < hi && hi <= tail.pos;

    return !(lo_cuts_state || hi_cuts_state);
}
21802219

21812220
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
21822221
const llama_batch & batch,

src/llama-kv-cache.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct llama_kv_cache : public llama_memory_i {
3737
// simulate full cache, used for allocating worst-case compute buffers
3838
virtual void set_full() = 0;
3939

40+
// reports whether seq_rm(seq_id, p0, p1) could be performed, without
41+
// mutating the cache; useful for validating a removal up front (e.g. a
42+
// hybrid cache with multiple children to keep in sync)
43+
virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0;
44+
4045
//
4146
// batch processing
4247
//
@@ -140,6 +145,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
140145

141146
void set_full() override;
142147

148+
// always returns true: the unified attention cache can always do a sequence removal
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
149+
143150
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
144151
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
145152

@@ -344,6 +351,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
344351

345352
void set_full() override;
346353

354+
// always returns true: both child caches can always do a sequence removal
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
355+
347356
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
348357
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
349358

@@ -450,6 +459,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
450459

451460
void set_full() override;
452461

462+
// returns false for a seq_id that is not below the cache size, or for a range
// that would only partially erase a sequence's state
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
463+
453464
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
454465
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
455466

0 commit comments

Comments
 (0)