Skip to content

Commit 228f724

Browse files
authored
kv-cache : fix seq_rm with seq_id == -1 (#15226)
* kv-cache : fix seq_rm with seq_id == -1 ggml-ci * cont : iterate over streams ggml-ci
1 parent cd3069d commit 228f724

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,7 @@ void llama_kv_cache_unified::clear(bool data) {
223223
}
224224

225225
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
226-
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
227-
228-
auto & cells = v_cells[seq_to_stream[seq_id]];
229-
auto & head = v_heads[seq_to_stream[seq_id]];
230-
231-
uint32_t new_head = cells.size();
226+
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
232227

233228
if (p0 < 0) {
234229
p0 = 0;
@@ -239,6 +234,11 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
239234
}
240235

241236
if (seq_id >= 0) {
237+
auto & cells = v_cells[seq_to_stream[seq_id]];
238+
auto & head = v_heads[seq_to_stream[seq_id]];
239+
240+
uint32_t new_head = cells.size();
241+
242242
for (uint32_t i = 0; i < cells.size(); ++i) {
243243
if (!cells.pos_in(i, p0, p1)) {
244244
continue;
@@ -250,24 +250,36 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
250250
}
251251
}
252252
}
253+
254+
// If we freed up a slot, set head to it so searching can start there.
255+
if (new_head != cells.size() && new_head < head) {
256+
head = new_head;
257+
}
253258
} else {
254259
// match any sequence
255-
for (uint32_t i = 0; i < cells.size(); ++i) {
256-
if (!cells.pos_in(i, p0, p1)) {
257-
continue;
258-
}
260+
for (uint32_t s = 0; s < n_stream; ++s) {
261+
auto & cells = v_cells[s];
262+
auto & head = v_heads[s];
259263

260-
cells.rm(i);
264+
uint32_t new_head = cells.size();
261265

262-
if (new_head == cells.size()) {
263-
new_head = i;
266+
for (uint32_t i = 0; i < cells.size(); ++i) {
267+
if (!cells.pos_in(i, p0, p1)) {
268+
continue;
269+
}
270+
271+
cells.rm(i);
272+
273+
if (new_head == cells.size()) {
274+
new_head = i;
275+
}
264276
}
265-
}
266-
}
267277

268-
// If we freed up a slot, set head to it so searching can start there.
269-
if (new_head != cells.size() && new_head < head) {
270-
head = new_head;
278+
// If we freed up a slot, set head to it so searching can start there.
279+
if (new_head != cells.size() && new_head < head) {
280+
head = new_head;
281+
}
282+
}
271283
}
272284

273285
return true;

0 commit comments

Comments
 (0)