Skip to content

Commit 29be51d

Browse files
committed
feat: Add can_seq_rm API to llama_kv_cache API
This will be key for the hybrid cache, which needs to be able to validate that all children can perform seq_rm cleanly before attempting to remove the seq from any single child, to avoid ending up in a corrupted state.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent ba118a2 commit 29be51d

File tree

2 files changed

+65
-15
lines changed

2 files changed

+65
-15
lines changed

src/llama-kv-cache.cpp

Lines changed: 54 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -392,6 +392,14 @@ void llama_kv_cache_unified::set_full() {
392392
head = 0;
393393
}
394394

395+
bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    // The unified attention cache places no restrictions on sequence
    // removal, so every request succeeds regardless of the arguments.
    GGML_UNUSED(seq_id);
    GGML_UNUSED(p0);
    GGML_UNUSED(p1);

    return true;
}
402+
395403
llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
396404
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
397405
}
@@ -1659,6 +1667,15 @@ void llama_kv_cache_unified_iswa::set_full() {
16591667
kv_swa ->set_full();
16601668
}
16611669

1670+
bool llama_kv_cache_unified_iswa::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    // Both children are unified attention caches, and those never refuse a
    // sequence removal, so the composite cache never refuses one either.
    GGML_UNUSED(seq_id);
    GGML_UNUSED(p0);
    GGML_UNUSED(p1);

    return true;
}
1678+
16621679
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
16631680
pending.clear();
16641681

@@ -1840,39 +1857,33 @@ void llama_kv_cache_recurrent::clear() {
18401857
}
18411858

18421859
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1843-
uint32_t new_head = size;
1860+
if (!can_seq_rm(seq_id, p0, p1)) {
1861+
// could be fatal
1862+
return false;
1863+
}
18441864

1865+
uint32_t new_head = size;
18451866
if (p0 < 0) {
18461867
p0 = 0;
18471868
}
1848-
18491869
if (p1 < 0) {
18501870
p1 = std::numeric_limits<llama_pos>::max();
18511871
}
18521872

1853-
// models like Mamba or RWKV can't have a state partially erased
1854-
if (seq_id >= (int64_t) size) {
1855-
// could be fatal
1856-
return false;
1857-
}
18581873
if (0 <= seq_id) {
18591874
int32_t & tail_id = cells[seq_id].tail;
18601875
if (tail_id >= 0) {
18611876
const kv_cell & cell = cells[tail_id];
1862-
// partial intersection is invalid
1863-
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1864-
return false;
1865-
}
1877+
// already validated in can_seq_rm
1878+
GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)));
18661879
// invalidate tails which will be cleared
18671880
if (p0 <= cell.pos && cell.pos < p1) {
18681881
tail_id = -1;
18691882
}
18701883
}
18711884
} else {
1872-
// seq_id is negative, then the range should include everything or nothing
1873-
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1874-
return false;
1875-
}
1885+
// already validated in can_seq_rm
1886+
GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())));
18761887
}
18771888

18781889
for (uint32_t i = 0; i < size; ++i) {
@@ -2089,6 +2100,34 @@ void llama_kv_cache_recurrent::set_full() {
20892100
n = size;
20902101
head = 0;
20912102
}
2103+
bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    // normalize negative positions to the full range [0, max)
    if (p0 < 0) {
        p0 = 0;
    }
    if (p1 < 0) {
        p1 = std::numeric_limits<llama_pos>::max();
    }

    // models like Mamba or RWKV can't have a state partially erased
    if (seq_id >= (int64_t) size) {
        // could be fatal
        return false;
    }

    if (seq_id < 0) {
        // a negative seq_id targets every sequence: only a removal of
        // everything (full range) or nothing (empty range) is representable
        const bool full_range  = p0 == 0 && p1 == std::numeric_limits<llama_pos>::max();
        const bool empty_range = p0 == p1;
        return full_range || empty_range;
    }

    const int32_t tail_id = cells[seq_id].tail;
    if (tail_id >= 0) {
        const kv_cell & cell = cells[tail_id];
        // a partial intersection with the tail cell cannot be represented
        if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
            return false;
        }
    }

    return true;
}
20922131

20932132
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
20942133
const llama_batch & batch,

src/llama-kv-cache.h

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,11 @@ struct llama_kv_cache : public llama_memory_i {
3939
// TODO: remove
4040
virtual void set_full() = 0;
4141

42+
// sometimes it is useful to check whether a cache can remove a sequence
43+
// before attempting to mutate the cache (eg a hybrid cache with multiple
44+
// children to keep in sync)
45+
virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0;
46+
4247
//
4348
// batch processing
4449
//
@@ -142,6 +147,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
142147

143148
void set_full() override;
144149

150+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
151+
145152
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
146153
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
147154

@@ -331,6 +338,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
331338

332339
void set_full() override;
333340

341+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
342+
334343
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
335344
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
336345

@@ -437,6 +446,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
437446

438447
void set_full() override;
439448

449+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
450+
440451
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
441452
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
442453

0 commit comments

Comments
 (0)