@@ -392,6 +392,14 @@ void llama_kv_cache_unified::set_full() {
392392 head = 0 ;
393393}
394394
395+ bool llama_kv_cache_unified::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
396+ GGML_UNUSED (seq_id);
397+ GGML_UNUSED (p0);
398+ GGML_UNUSED (p1);
399+ // Unified attention cache can always do a sequence removal
400+ return true ;
401+ }
402+
395403llama_sbatch llama_kv_cache_unified::sbatch_init (const llama_batch & batch, bool logits_all) {
396404 return llama_sbatch (batch, hparams.n_embd , true , logits_all);
397405}
@@ -1659,6 +1667,15 @@ void llama_kv_cache_unified_iswa::set_full() {
16591667 kv_swa ->set_full ();
16601668}
16611669
1670+ bool llama_kv_cache_unified_iswa::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
1671+ GGML_UNUSED (seq_id);
1672+ GGML_UNUSED (p0);
1673+ GGML_UNUSED (p1);
1674+ // Unified attention caches can always do a sequence removal, so since both
1675+ // children can, the parent can as well.
1676+ return true ;
1677+ }
1678+
16621679llama_sbatch llama_kv_cache_unified_iswa::sbatch_init (const llama_batch & batch, bool logits_all) {
16631680 pending.clear ();
16641681
@@ -1840,39 +1857,33 @@ void llama_kv_cache_recurrent::clear() {
18401857}
18411858
18421859bool llama_kv_cache_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1843- uint32_t new_head = size;
1860+ if (!can_seq_rm (seq_id, p0, p1)) {
1861+ // could be fatal
1862+ return false ;
1863+ }
18441864
1865+ uint32_t new_head = size;
18451866 if (p0 < 0 ) {
18461867 p0 = 0 ;
18471868 }
1848-
18491869 if (p1 < 0 ) {
18501870 p1 = std::numeric_limits<llama_pos>::max ();
18511871 }
18521872
1853- // models like Mamba or RWKV can't have a state partially erased
1854- if (seq_id >= (int64_t ) size) {
1855- // could be fatal
1856- return false ;
1857- }
18581873 if (0 <= seq_id) {
18591874 int32_t & tail_id = cells[seq_id].tail ;
18601875 if (tail_id >= 0 ) {
18611876 const kv_cell & cell = cells[tail_id];
1862- // partial intersection is invalid
1863- if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1864- return false ;
1865- }
1877+ // already validated in can_seq_rm
1878+ GGML_ASSERT (!((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )));
18661879 // invalidate tails which will be cleared
18671880 if (p0 <= cell.pos && cell.pos < p1) {
18681881 tail_id = -1 ;
18691882 }
18701883 }
18711884 } else {
1872- // seq_id is negative, then the range should include everything or nothing
1873- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
1874- return false ;
1875- }
1885+ // already validated in can_seq_rm
1886+ GGML_ASSERT (!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())));
18761887 }
18771888
18781889 for (uint32_t i = 0 ; i < size; ++i) {
@@ -2089,6 +2100,34 @@ void llama_kv_cache_recurrent::set_full() {
20892100 n = size;
20902101 head = 0 ;
20912102}
2103+ bool llama_kv_cache_recurrent::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2104+ if (p0 < 0 ) {
2105+ p0 = 0 ;
2106+ }
2107+
2108+ if (p1 < 0 ) {
2109+ p1 = std::numeric_limits<llama_pos>::max ();
2110+ }
2111+ // models like Mamba or RWKV can't have a state partially erased
2112+ if (seq_id >= (int64_t ) size) {
2113+ // could be fatal
2114+ return false ;
2115+ }
2116+ if (0 <= seq_id) {
2117+ const int32_t & tail_id = cells[seq_id].tail ;
2118+ if (tail_id >= 0 ) {
2119+ const kv_cell & cell = cells[tail_id];
2120+ // partial intersection is invalid
2121+ if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
2122+ return false ;
2123+ }
2124+ }
2125+ // seq_id is negative, then the range should include everything or nothing
2126+ } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
2127+ return false ;
2128+ }
2129+ return true ;
2130+ }
20922131
20932132llama_sbatch llama_kv_cache_recurrent::sbatch_init (
20942133 const llama_batch & batch,
0 commit comments