Commit e3af556

ngxson and ggerganov authored

llama: store mrope data in KV cell (#16825)

* llama: store mrope data in KV cell
* correct x,y ordering
* address review comments
* add consistency checks
* Update src/llama-kv-cache.cpp

  Co-authored-by: Georgi Gerganov <[email protected]>

* add TODO
* fix asan error
* kv-cells : improve ext handling
* cont : fix headers

---------

Co-authored-by: Georgi Gerganov <[email protected]>

1 parent 10fcc41 commit e3af556

6 files changed, +143 -32 lines changed


src/llama-batch.cpp

Lines changed: 41 additions & 26 deletions

@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
         /*.n_seq_tokens =*/ (uint32_t) 1,
         /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
         /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
         /*.token        =*/ batch.token,
         /*.embd         =*/ batch.embd,
         /*.pos          =*/ batch.pos,
@@ -251,46 +252,58 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (uint32_t s = 0; s < n_seq_max; ++s) {
-        if (seq_pos[s].empty()) {
-            continue;
+    if (n_pos_per_embd > 1) {
+        // M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
+
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+            if (p0 >= 0 && p0 >= seq_pos_min(s)) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " for M-RoPE, it is required that the position satisfies: X < Y\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
         }
+    } else {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
 
-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
 
-        if (p0 >= 0) {
-            bool ok = true;
+            if (p0 >= 0) {
+                bool ok = true;
 
-            if (batch.token) {
                 if (seq_pos_min(s) != p0 + 1) {
                     ok = false;
                 }
-            } else {
-                assert(batch.embd);
 
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
+                if (!ok) {
+                    LLAMA_LOG_ERROR(
+                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                            " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                            __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
                 }
             }
 
-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                    "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                    " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                    " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                    " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                    __func__, s, s, p0, s, seq_pos_min(s));
-
+            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
                 return false;
             }
         }
-
-        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
-            return false;
-        }
     }
 
     if (memory) {
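
Note on the new checks: in the regular case a follow-up batch has to continue each sequence exactly (Y = X + 1) and may not contain position gaps, while in the M-RoPE case (n_pos_per_embd > 1) the batch only has to start strictly after the last cached position (X < Y), since the positions of image embeddings can jump forward. A minimal standalone sketch of the two rules, using plain integers and illustrative names rather than the llama.cpp API:

    // Illustrative sketch of the two checks above (not the actual llama.cpp code):
    //   p0    - last position already stored in the KV cache for this sequence (-1 if none)
    //   first - smallest position of this sequence in the incoming batch
    //   last  - largest  position of this sequence in the incoming batch
    //   count - number of tokens of this sequence in the incoming batch
    bool positions_ok(int p0, int first, int last, int count, bool is_mrope) {
        if (is_mrope) {
            // M-RoPE: positions may jump forward, but must start strictly after the cache (X < Y)
            return p0 < 0 || p0 < first;
        }
        // regular case: continue the sequence exactly (Y = X + 1) ...
        if (p0 >= 0 && first != p0 + 1) {
            return false;
        }
        // ... and have no gaps inside the batch itself
        return last - first + 1 <= count;
    }
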
@@ -389,6 +402,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,

@@ -710,6 +724,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ batch.token ? udata->token.data() : nullptr,
         /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,

src/llama-batch.h

Lines changed: 12 additions & 1 deletion

@@ -17,6 +17,16 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
+    // typical for M-RoPE cases:
+    //   0 - sequantial position of the tokens/embeddings in the sequence
+    //   1 - y position in the image
+    //   2 - x position in the image
+    //   3 - other
+    bool is_pos_2d() const {
+        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
+        return n_pos >= 3;
+    }
+
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
                            //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?

@@ -25,6 +35,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // number of position inputs for each token/embedding
 
     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)

@@ -33,7 +44,7 @@ struct llama_ubatch {
     //                          // size               | idx | val
     llama_token  *  token;      // [n_tokens]         | i   | id, token
     float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
     int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
     llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
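
Note on the layout: with n_pos position values per token, pos holds n_pos blocks of n_tokens entries each; block 0 is the sequential position, block 1 the image row (y) and block 2 the image column (x). This is the indexing used by apply_ubatch() and set_input_kq_mask() in llama-kv-cache.cpp below. A small self-contained sketch of reading one token's positions (the view struct and helper are illustrative, not part of the API):

    #include <cstdint>

    using llama_pos = int32_t; // stand-in for the typedef in llama.h

    // illustrative mirror of the relevant llama_ubatch fields
    struct ubatch_view {
        uint32_t    n_tokens;
        uint32_t    n_pos; // 1 for regular RoPE, >= 3 for M-RoPE (seq, y, x, ...)
        llama_pos * pos;   // [n_tokens*n_pos], stored block by block
    };

    struct token_pos {
        llama_pos seq; // sequential position in the sequence
        llama_pos y;   // image row    (meaningful only when n_pos >= 3)
        llama_pos x;   // image column (meaningful only when n_pos >= 3)
    };

    token_pos get_token_pos(const ubatch_view & u, uint32_t i) {
        token_pos res = { u.pos[i], 0, 0 };
        if (u.n_pos >= 3) {
            res.y = u.pos[i + u.n_tokens];   // block 1: y
            res.x = u.pos[i + u.n_tokens*2]; // block 2: x
        }
        return res;
    }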

src/llama-kv-cache.cpp

Lines changed: 32 additions & 0 deletions

@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
             llama_pos pos   = v_cells[s0].pos_get(i);
             llama_pos shift = v_cells[s0].get_shift(i);
 
+            llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
+
             if (shift != 0) {
                 pos -= shift;
                 assert(pos >= 0);

@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
             if (shift != 0) {
                 v_cells[s1].pos_add(i, shift);
             }
+
+            v_cells[s1].ext_set(i, ext);
         }
     }
 

@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
 
 void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
     auto & head  = v_heads[seq_to_stream[seq_id]];

@@ -427,6 +432,7 @@
 
 void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
 

@@ -900,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         cells.pos_set(idx, ubatch.pos[i]);
 
+        if (ubatch.is_pos_2d()) {
+            llama_kv_cell_ext ext {
+                /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
+                /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+            };
+            cells.ext_set(idx, ext);
+        }
+
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
             cells.seq_add(idx, ubatch.seq_id[i][s]);
         }

@@ -1247,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
 
             const llama_pos p1 = ubatch->pos[i];
 
+            // for M-RoPE
+            const bool is_2d = ubatch->is_pos_2d();
+            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
+
             const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
             for (uint32_t j = 0; j < n_kv; ++j) {

@@ -1266,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                     continue;
                 }
 
+                // M-RoPE causal mask
+                if (causal_attn && is_2d && p0 == p1) {
+                    const auto & p0_ext = cells.ext_get(j);
+                    if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                        continue;
+                    }
+                }
+
                 // apply SWA if any
                 if (is_masked_swa(p0, p1)) {
                     continue;
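
Note on the M-RoPE causal mask: all embeddings of one image share the same sequential position, so the usual p0 > p1 comparison cannot order tokens within that image; the new branch also compares the stored 2D positions in raster order (row y first, then column x). A sketch of the combined decision for one (KV cell, query token) pair, using plain values rather than the actual cache state:

    // returns true if KV cell (p0, x0, y0) must be masked out for query token (p1, x1, y1)
    // (sketch of the causal part only; SWA and sequence checks are omitted)
    bool masked_causal(int p0, int x0, int y0,
                       int p1, int x1, int y1, bool is_2d) {
        if (p0 > p1) {
            return true; // regular causal mask on the sequential position
        }
        if (is_2d && p0 == p1) {
            // same image: mask cells that come later in raster order, i.e. (y0, x0) > (y1, x1)
            return (y0 > y1) || (y0 == y1 && x0 > x1);
        }
        return false;
    }
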
@@ -1559,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
             io.write(&pos, sizeof(pos));
             io.write(&n_seq_id, sizeof(n_seq_id));
 
+            // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
+            // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+
             for (const auto & seq_id : seq_ids) {
                 io.write(&seq_id, sizeof(seq_id));
             }

@@ -1704,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             return false;
         }
 
+        // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
+        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);
 
         const auto head_cur = sinfo.head();

src/llama-kv-cells.h

Lines changed: 44 additions & 2 deletions

@@ -5,9 +5,27 @@
 
 #include <bitset>
 #include <cassert>
-#include <vector>
-#include <set>
+#include <cstring>
 #include <map>
+#include <set>
+#include <vector>
+
+struct llama_kv_cell_ext {
+    // 2D spatial positions, typically used for M-RoPE
+    llama_pos x = 0;
+    llama_pos y = 0;
+
+    // return true if the current 2D spatial position is greater than other
+    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
+        return (y > oy) || (y == oy && x > ox);
+    }
+
+    void reset() {
+        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
+
+        memset(this, 0, sizeof(*this));
+    }
+};
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests

@@ -16,6 +34,7 @@ class llama_kv_cells {
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
             seq[i].reset();
         }

@@ -43,6 +62,7 @@ class llama_kv_cells {
 
     void resize(uint32_t n) {
        pos.resize(n);
+       ext.resize(n);
        shift.resize(n);
        seq.resize(n);
 

@@ -108,6 +128,7 @@ class llama_kv_cells {
            const auto idx = i + j;
 
            res.pos[j] = pos[idx];
+           res.ext[j] = ext[idx];
            res.seq[j] = seq[idx];
 
            assert(shift[idx] == 0);

@@ -126,6 +147,7 @@ class llama_kv_cells {
            const auto idx = idxs[j];
 
            res.pos[j] = pos[idx];
+           res.ext[j] = ext[idx];
            res.seq[j] = seq[idx];
 
            assert(shift[idx] == 0);

@@ -154,6 +176,7 @@ class llama_kv_cells {
            }
 
            pos[idx] = other.pos[j];
+           ext[idx] = other.ext[j];
            seq[idx] = other.seq[j];
 
            if (pos[idx] != -1) {

@@ -184,6 +207,7 @@ class llama_kv_cells {
            }
 
            pos[idx] = other.pos[j];
+           ext[idx] = other.ext[j];
            seq[idx] = other.seq[j];
 
            if (pos[idx] != -1) {

@@ -203,6 +227,7 @@ class llama_kv_cells {
        seq[i].reset();
 
        pos[i] = -1;
+       ext[i].reset();
        shift[i] = 0;
 
        used.erase(i);

@@ -221,6 +246,7 @@ class llama_kv_cells {
 
        if (seq[i].none()) {
            pos[i] = -1;
+           ext[i].reset();
            shift[i] = 0;
 
            used.erase(i);

@@ -250,6 +276,7 @@ class llama_kv_cells {
        seq[i].reset();
 
        pos[i] = -1;
+       ext[i].reset();
        shift[i] = 0;
 
        used.erase(i);

@@ -340,6 +367,13 @@ class llama_kv_cells {
         return pos[i];
     }
 
+    const llama_kv_cell_ext & ext_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return ext[i];
+    }
+
     // note: call only if the cell is not empty
     llama_pos get_shift(uint32_t i) const {
         assert(i < pos.size());

@@ -368,6 +402,11 @@ class llama_kv_cells {
         used.insert(i);
     }
 
+    void ext_set(uint32_t i, llama_kv_cell_ext p) {
+        assert(i < ext.size());
+        ext[i] = p;
+    }
+
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty

@@ -424,6 +463,9 @@ class llama_kv_cells {
 
     std::vector<llama_pos> pos;
 
+    // stores extra info per cell
+    std::vector<llama_kv_cell_ext> ext;
+
     // this array accumulates any applied shifts to the pos array since the last reset_shift() call
     // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
     //
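
Note on llama_kv_cell_ext: it is stored in a vector parallel to pos, so every occupied cell now also carries the 2D source location of its embedding, and it is trivially copyable, which is what allows reset() to fall back to a plain memset. A standalone sanity check of the raster-order comparison (the struct is copied from the header above; llama_pos is stubbed as int32_t):

    #include <cassert>
    #include <cstdint>
    #include <cstring>
    #include <type_traits>

    using llama_pos = int32_t; // stand-in for the typedef in llama.h

    struct llama_kv_cell_ext {
        llama_pos x = 0;
        llama_pos y = 0;

        bool is_2d_gt(llama_pos ox, llama_pos oy) const {
            return (y > oy) || (y == oy && x > ox);
        }

        void reset() {
            static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
            memset(this, 0, sizeof(*this));
        }
    };

    int main() {
        llama_kv_cell_ext a{/*.x =*/ 3, /*.y =*/ 1};

        assert( a.is_2d_gt(2, 1)); // same row, larger column -> a is "later"
        assert( a.is_2d_gt(5, 0)); // any cell on an earlier row compares smaller
        assert(!a.is_2d_gt(3, 1)); // an equal position is not strictly greater

        a.reset();
        assert(a.x == 0 && a.y == 0);
        return 0;
    }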
