@@ -81,11 +81,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
8181 v_cells[s].resize (kv_size);
8282 }
8383
84- // by default, all sequence ids are mapped to the 0th stream
85- seq_to_stream.resize (LLAMA_MAX_SEQ, 0 );
84+ seq_to_stream.resize (n_seq_max, 0 );
8685
8786 if (n_stream > 1 ) {
88- seq_to_stream.resize (n_stream, 0 );
8987 for (uint32_t s = 0 ; s < n_stream; ++s) {
9088 seq_to_stream[s] = s;
9189 }
@@ -223,12 +221,9 @@ void llama_kv_cache_unified::clear(bool data) {
223221}
224222
225223bool llama_kv_cache_unified::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
226- GGML_ASSERT (seq_id >= 0 && (size_t ) seq_id < seq_to_stream.size ());
227-
228- auto & cells = v_cells[seq_to_stream[seq_id]];
229- auto & head = v_heads[seq_to_stream[seq_id]];
230-
231- uint32_t new_head = cells.size ();
224+ if (seq_id != -1 ) {
225+ GGML_ASSERT (seq_id >= 0 && (size_t ) seq_id < seq_to_stream.size ());
226+ }
232227
233228 if (p0 < 0 ) {
234229 p0 = 0 ;
@@ -239,6 +234,11 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
239234 }
240235
241236 if (seq_id >= 0 ) {
237+ auto & cells = v_cells[seq_to_stream[seq_id]];
238+ auto & head = v_heads[seq_to_stream[seq_id]];
239+
240+ uint32_t new_head = cells.size ();
241+
242242 for (uint32_t i = 0 ; i < cells.size (); ++i) {
243243 if (!cells.pos_in (i, p0, p1)) {
244244 continue ;
@@ -250,24 +250,34 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
250250 }
251251 }
252252 }
253+
254+ if (new_head != cells.size () && new_head < head) {
255+ head = new_head;
256+ }
253257 } else {
254258 // match any sequence
255- for (uint32_t i = 0 ; i < cells.size (); ++i) {
256- if (!cells.pos_in (i, p0, p1)) {
257- continue ;
258- }
259+ for (uint32_t s = 0 ; s < n_stream; ++s) {
260+ auto & cells = v_cells[s];
261+ auto & head = v_heads[s];
259262
260- cells.rm (i );
263+ uint32_t new_head = cells.size ( );
261264
262- if (new_head == cells.size ()) {
263- new_head = i;
265+ for (uint32_t i = 0 ; i < cells.size (); ++i) {
266+ if (!cells.pos_in (i, p0, p1)) {
267+ continue ;
268+ }
269+
270+ cells.rm (i);
271+
272+ if (new_head == cells.size ()) {
273+ new_head = i;
274+ }
264275 }
265- }
266- }
267276
268- // If we freed up a slot, set head to it so searching can start there.
269- if (new_head != cells.size () && new_head < head) {
270- head = new_head;
277+ if (new_head != cells.size () && new_head < head) {
278+ head = new_head;
279+ }
280+ }
271281 }
272282
273283 return true ;
0 commit comments