Commit bf7f924

llama: store mrope data in KV cell

1 parent 1c1409e

File tree: 4 files changed, +70 -4 lines

src/llama-batch.cpp

Lines changed: 3 additions & 0 deletions
@@ -660,6 +660,9 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  = (int64_t) n_tokens*n_pos_cur;
 
+    // printf("ubatch_add: n_tokens=%d, n_seqs=%d, n_pos_cur=%d, n_embd_all=%lld, n_pos_all=%lld\n",
+    //     n_tokens, n_seqs, n_pos_cur, n_embd_all, n_pos_all);
+
     udata->token .resize(n_tokens);
     udata->embd  .resize(n_embd_all);
     udata->pos   .resize(n_pos_all);
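The n_pos_all sizing is driven by n_pos_cur: a ubatch stores one position entry per token for standard RoPE and four for M-RoPE, laid out as consecutive planes of n_tokens entries each. A minimal sketch of reading those planes, assuming the planar layout implied by the indexing in llama-kv-cache.cpp below (the helper name is hypothetical, not part of the commit):

    // hypothetical helper: read position plane k
    // (0 = primary pos, 1 = x, 2 = y, 3 = t) of token i in a ubatch
    static llama_pos ubatch_pos_plane(const llama_ubatch & ub, uint32_t i, uint32_t k) {
        return ub.pos[i + (size_t) k*ub.n_tokens];
    }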

src/llama-batch.h

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,10 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
+    bool has_mrope() const {
+        return data->pos.size() == data->token.size()*4;
+    }
+
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
                            //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
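has_mrope() detects the M-RoPE case purely from buffer sizes: the position buffer holds exactly four entries per token (primary position plus the x/y/t coordinates). A minimal sketch, under that layout assumption, of gathering the coordinates of one token (helper name and placement are hypothetical):

    // mirrors the indexing used by llama_kv_cache::apply_ubatch() below
    static llama_kv_pos_mrope ubatch_get_pos_mrope(const llama_ubatch & ub, uint32_t i) {
        return {
            ub.pos[i + ub.n_tokens],   // x
            ub.pos[i + ub.n_tokens*2], // y
            ub.pos[i + ub.n_tokens*3], // t
        };
    }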

src/llama-kv-cache.cpp

Lines changed: 24 additions & 0 deletions
@@ -900,6 +900,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         cells.pos_set(idx, ubatch.pos[i]);
 
+        if (ubatch.has_mrope()) {
+            cells.pos_mrope_set(idx, {
+                ubatch.pos[i + ubatch.n_tokens],   // x
+                ubatch.pos[i + ubatch.n_tokens*2], // y
+                ubatch.pos[i + ubatch.n_tokens*3], // t
+            });
+        }
+
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
             cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
@@ -1243,6 +1251,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
 
                 const llama_pos p1 = ubatch->pos[i];
 
+                // for M-RoPE
+                llama_kv_pos_mrope p1_mrope;
+                if (ubatch->has_mrope()) {
+                    p1_mrope.x = ubatch->pos[i + ubatch->n_tokens];
+                    p1_mrope.y = ubatch->pos[i + ubatch->n_tokens*2];
+                    p1_mrope.t = ubatch->pos[i + ubatch->n_tokens*3];
+                }
+
                 const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
                 for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1262,6 +1278,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                         continue;
                     }
 
+                    // M-RoPE causal mask
+                    if (causal_attn && ubatch->has_mrope() && p0 == p1) {
+                        const auto & p0_mrope = cells.pos_mrope_get(j);
+                        if (p0_mrope.is_gt(p1_mrope)) {
+                            continue;
+                        }
+                    }
+
                    // apply SWA if any
                    if (is_masked_swa(p0, p1)) {
                        continue;
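With M-RoPE, the tokens of a single media chunk (e.g. the patches of one image) can share the same primary position, so p0 == p1 alone cannot decide causality; the tie is broken by comparing the (t, y, x) tuples stored in the cells. An illustrative example of the tie-break (values invented for the example):

    // two cells at the same primary position p0 == p1,
    // e.g. two patches of the same image
    llama_kv_pos_mrope a = { /*x =*/ 0, /*y =*/ 1, /*t =*/ 0 };
    llama_kv_pos_mrope b = { /*x =*/ 3, /*y =*/ 0, /*t =*/ 0 };

    // same t, but a.y > b.y, so "a" is ordered after "b": under causal
    // attention, the query at "b" must not attend to the cell at "a"
    assert( a.is_gt(b));
    assert(!b.is_gt(a));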

src/llama-kv-cells.h

Lines changed: 39 additions & 4 deletions
@@ -9,6 +9,18 @@
 #include <set>
 #include <map>
 
+struct llama_kv_pos_mrope {
+    llama_pos x;
+    llama_pos y;
+    llama_pos t;
+    // return true if this position is greater than the other position
+    bool is_gt(const llama_kv_pos_mrope & other) const {
+        return (t > other.t)
+            || (t == other.t && y > other.y)
+            || (t == other.t && y == other.y && x > other.x);
+    }
+};
+
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
 class llama_kv_cells {
@@ -43,6 +55,7 @@ class llama_kv_cells {
 
     void resize(uint32_t n) {
         pos.resize(n);
+        pos_mrope.resize(n);
         shift.resize(n);
         seq.resize(n);
 
@@ -107,8 +120,9 @@ class llama_kv_cells {
         for (uint32_t j = 0; j < n; ++j) {
             const auto idx = i + j;
 
-            res.pos[j] = pos[idx];
-            res.seq[j] = seq[idx];
+            res.pos      [j] = pos[idx];
+            res.pos_mrope[j] = pos_mrope[idx];
+            res.seq      [j] = seq[idx];
 
             assert(shift[idx] == 0);
         }
@@ -125,8 +139,9 @@ class llama_kv_cells {
         for (uint32_t j = 0; j < idxs.size(); ++j) {
             const auto idx = idxs[j];
 
-            res.pos[j] = pos[idx];
-            res.seq[j] = seq[idx];
+            res.pos      [j] = pos[idx];
+            res.pos_mrope[j] = pos_mrope[idx];
+            res.seq      [j] = seq[idx];
 
             assert(shift[idx] == 0);
         }
@@ -340,6 +355,13 @@ class llama_kv_cells {
         return pos[i];
     }
 
+    const llama_kv_pos_mrope & pos_mrope_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return pos_mrope[i];
+    }
+
     // note: call only if the cell is not empty
     llama_pos get_shift(uint32_t i) const {
         assert(i < pos.size());
@@ -368,6 +390,16 @@ class llama_kv_cells {
         used.insert(i);
     }
 
+    void pos_mrope_set(uint32_t i, llama_kv_pos_mrope p) {
+        assert(i < pos.size());
+        assert(pos[i] == -1);
+        assert(seq[i].none());
+
+        pos_mrope[i] = p;
+
+        used.insert(i);
+    }
+
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty
@@ -424,6 +456,9 @@ class llama_kv_cells {
 
     std::vector<llama_pos> pos;
 
+    // stores additional info for M-RoPE positions
+    std::vector<llama_kv_pos_mrope> pos_mrope;
+
     // this array accumulates any applied shifts to the pos array since the last reset_shift() call
     // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
     //
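The ordering implemented by is_gt() is a strict lexicographic "greater than" on the tuple (t, y, x): the temporal component dominates, then the row, then the column. An equivalent formulation (not in the commit) using std::tie from <tuple>:

    bool is_gt(const llama_kv_pos_mrope & other) const {
        return std::tie(t, y, x) > std::tie(other.t, other.y, other.x);
    }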
