@@ -464,6 +464,14 @@ void llama_kv_cache_unified::set_full() {
464464 head = 0 ;
465465}
466466
467+ bool llama_kv_cache_unified::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
468+ GGML_UNUSED (seq_id);
469+ GGML_UNUSED (p0);
470+ GGML_UNUSED (p1);
471+ // Unified attention cache can always do a sequence removal
472+ return true ;
473+ }
474+
467475llama_sbatch llama_kv_cache_unified::sbatch_init (const llama_batch & batch, bool logits_all) {
468476 return llama_sbatch (batch, hparams.n_embd , true , logits_all);
469477}
@@ -1747,6 +1755,15 @@ void llama_kv_cache_unified_iswa::set_full() {
17471755 kv_swa ->set_full ();
17481756}
17491757
1758+ bool llama_kv_cache_unified_iswa::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
1759+ GGML_UNUSED (seq_id);
1760+ GGML_UNUSED (p0);
1761+ GGML_UNUSED (p1);
1762+ // Unified attention caches can always do a sequence removal, so since both
1763+ // children can, the parent can as well.
1764+ return true ;
1765+ }
1766+
17501767llama_sbatch llama_kv_cache_unified_iswa::sbatch_init (const llama_batch & batch, bool logits_all) {
17511768 pending.clear ();
17521769
@@ -1928,39 +1945,33 @@ void llama_kv_cache_recurrent::clear() {
19281945}
19291946
19301947bool llama_kv_cache_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1931- uint32_t new_head = size;
1948+ if (!can_seq_rm (seq_id, p0, p1)) {
1949+ // could be fatal
1950+ return false ;
1951+ }
19321952
1953+ uint32_t new_head = size;
19331954 if (p0 < 0 ) {
19341955 p0 = 0 ;
19351956 }
1936-
19371957 if (p1 < 0 ) {
19381958 p1 = std::numeric_limits<llama_pos>::max ();
19391959 }
19401960
1941- // models like Mamba or RWKV can't have a state partially erased
1942- if (seq_id >= (int64_t ) size) {
1943- // could be fatal
1944- return false ;
1945- }
19461961 if (0 <= seq_id) {
19471962 int32_t & tail_id = cells[seq_id].tail ;
19481963 if (tail_id >= 0 ) {
19491964 const kv_cell & cell = cells[tail_id];
1950- // partial intersection is invalid
1951- if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1952- return false ;
1953- }
1965+ // already validated in can_seq_rm
1966+ GGML_ASSERT (!((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )));
19541967 // invalidate tails which will be cleared
19551968 if (p0 <= cell.pos && cell.pos < p1) {
19561969 tail_id = -1 ;
19571970 }
19581971 }
19591972 } else {
1960- // seq_id is negative, then the range should include everything or nothing
1961- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
1962- return false ;
1963- }
1973+ // already validated in can_seq_rm
1974+ GGML_ASSERT (!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())));
19641975 }
19651976
19661977 for (uint32_t i = 0 ; i < size; ++i) {
@@ -2177,6 +2188,34 @@ void llama_kv_cache_recurrent::set_full() {
21772188 n = size;
21782189 head = 0 ;
21792190}
2191+ bool llama_kv_cache_recurrent::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2192+ if (p0 < 0 ) {
2193+ p0 = 0 ;
2194+ }
2195+
2196+ if (p1 < 0 ) {
2197+ p1 = std::numeric_limits<llama_pos>::max ();
2198+ }
2199+ // models like Mamba or RWKV can't have a state partially erased
2200+ if (seq_id >= (int64_t ) size) {
2201+ // could be fatal
2202+ return false ;
2203+ }
2204+ if (0 <= seq_id) {
2205+ const int32_t & tail_id = cells[seq_id].tail ;
2206+ if (tail_id >= 0 ) {
2207+ const kv_cell & cell = cells[tail_id];
2208+ // partial intersection is invalid
2209+ if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
2210+ return false ;
2211+ }
2212+ }
2213+ // seq_id is negative, then the range should include everything or nothing
2214+ } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
2215+ return false ;
2216+ }
2217+ return true ;
2218+ }
21802219
21812220llama_sbatch llama_kv_cache_recurrent::sbatch_init (
21822221 const llama_batch & batch,
0 commit comments