Skip to content

Commit bed0f57

Browse files
committed
kv-cells : improve ext handling
1 parent 5ec41a1 commit bed0f57

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/llama-kv-cache.cpp

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
338338
llama_pos pos = v_cells[s0].pos_get(i);
339339
llama_pos shift = v_cells[s0].get_shift(i);
340340

341+
llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
342+
341343
if (shift != 0) {
342344
pos -= shift;
343345
assert(pos >= 0);
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
349351
if (shift != 0) {
350352
v_cells[s1].pos_add(i, shift);
351353
}
354+
355+
v_cells[s1].ext_set(i, ext);
352356
}
353357
}
354358

@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
383387

384388
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
385389
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
390+
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
386391

387392
auto & cells = v_cells[seq_to_stream[seq_id]];
388393
auto & head = v_heads[seq_to_stream[seq_id]];
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
427432

428433
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
429434
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
435+
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
430436

431437
auto & cells = v_cells[seq_to_stream[seq_id]];
432438

@@ -905,7 +911,7 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
905911
/*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
906912
/*.y =*/ ubatch.pos[i + ubatch.n_tokens],
907913
};
908-
cells.ext_set(idx, std::move(ext));
914+
cells.ext_set(idx, ext);
909915
}
910916

911917
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {

src/llama-kv-cells.h

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,12 @@ struct llama_kv_cell_ext {
1818
bool is_2d_gt(llama_pos ox, llama_pos oy) const {
1919
return (y > oy) || (y == oy && x > ox);
2020
}
21+
22+
void reset() {
23+
static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
24+
25+
memset(this, 0, sizeof(*this));
26+
}
2127
};
2228

2329
// meta information about KV cells that can be part of multiple sequences at the same time
@@ -27,6 +33,7 @@ class llama_kv_cells {
2733
void reset() {
2834
for (uint32_t i = 0; i < pos.size(); ++i) {
2935
pos[i] = -1;
36+
ext[i].reset();
3037
shift[i] = 0;
3138
seq[i].reset();
3239
}
@@ -168,6 +175,7 @@ class llama_kv_cells {
168175
}
169176

170177
pos[idx] = other.pos[j];
178+
ext[idx] = other.ext[j];
171179
seq[idx] = other.seq[j];
172180

173181
if (pos[idx] != -1) {
@@ -198,6 +206,7 @@ class llama_kv_cells {
198206
}
199207

200208
pos[idx] = other.pos[j];
209+
ext[idx] = other.ext[j];
201210
seq[idx] = other.seq[j];
202211

203212
if (pos[idx] != -1) {
@@ -217,6 +226,7 @@ class llama_kv_cells {
217226
seq[i].reset();
218227

219228
pos[i] = -1;
229+
ext[i].reset();
220230
shift[i] = 0;
221231

222232
used.erase(i);
@@ -235,6 +245,7 @@ class llama_kv_cells {
235245

236246
if (seq[i].none()) {
237247
pos[i] = -1;
248+
ext[i].reset();
238249
shift[i] = 0;
239250

240251
used.erase(i);
@@ -264,6 +275,7 @@ class llama_kv_cells {
264275
seq[i].reset();
265276

266277
pos[i] = -1;
278+
ext[i].reset();
267279
shift[i] = 0;
268280

269281
used.erase(i);
@@ -389,9 +401,9 @@ class llama_kv_cells {
389401
used.insert(i);
390402
}
391403

392-
void ext_set(uint32_t i, llama_kv_cell_ext && p) {
404+
void ext_set(uint32_t i, llama_kv_cell_ext p) {
393405
assert(i < ext.size());
394-
ext[i] = std::move(p);
406+
ext[i] = p;
395407
}
396408

397409
// pos[i] = pos[i] + d

0 commit comments

Comments
 (0)