Commit c3e1393: address review comments

1 parent 90353ea

File tree: 4 files changed, +44 −35 lines


src/llama-batch.cpp

Lines changed: 3 additions & 0 deletions
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
         /*.n_seq_tokens =*/ (uint32_t) 1,
         /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
         /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
         /*.token        =*/ batch.token,
         /*.embd         =*/ batch.embd,
         /*.pos          =*/ batch.pos,
@@ -382,6 +383,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
+        /*.n_pos        =*/ n_pos_per_embd,

         /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
@@ -703,6 +705,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,

         /*.token        =*/ batch.token ? udata->token.data() : nullptr,
         /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
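All three hunks thread the same value through: each llama_ubatch now carries how many position values accompany every token, instead of inferring it from buffer sizes (the removed has_mrope() below compared against token.size()*4). A minimal sketch of how a caller might size the flat pos buffer to match; this helper is illustrative, not part of the commit:

// Hypothetical sketch: sizing the position buffer for a ubatch.
// With plain RoPE, n_pos_per_embd is 1; with M-RoPE it is 4
// (sequential, y, x, other), so the buffer holds n_tokens*n_pos entries.
#include <cstdint>
#include <vector>

using llama_pos = int32_t; // as in llama.h

std::vector<llama_pos> alloc_pos(uint32_t n_tokens, uint32_t n_pos_per_embd) {
    // plane d of token i lives at index i + n_tokens*d
    return std::vector<llama_pos>(size_t(n_tokens) * n_pos_per_embd, 0);
}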

src/llama-batch.h

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,14 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }

-    bool has_mrope() const {
-        return data->pos.size() == data->token.size()*4;
+    // typical for M-RoPE cases:
+    //   0 - sequential position of the tokens/embeddings in the sequence
+    //   1 - y position in the image
+    //   2 - x position in the image
+    //   3 - other
+    bool is_pos_2d() const {
+        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
+        return n_pos >= 3;
     }

     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
@@ -29,6 +35,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // number of position inputs for each token/embedding

     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -37,7 +44,7 @@ struct llama_ubatch {
     //                          // size               | idx | val
     llama_token  *  token;      // [n_tokens]         | i   | id, token
     float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
     int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
     llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
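Given the [n_tokens*n_pos] layout documented above, reading a token's positions back out is plain index arithmetic. A short sketch, assuming llama-batch.h is included and ubatch.is_pos_2d() holds; the helper name is illustrative:

// Illustrative only: decode the 2D position planes of token i,
// matching the indexing used in llama-kv-cache.cpp below.
#include <cstdint>
#include <utility>

static std::pair<llama_pos, llama_pos> token_xy(const llama_ubatch & ubatch, uint32_t i) {
    const llama_pos y = ubatch.pos[i + ubatch.n_tokens];    // plane 1: image row
    const llama_pos x = ubatch.pos[i + ubatch.n_tokens*2];  // plane 2: image column
    return {x, y};
}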

src/llama-kv-cache.cpp

Lines changed: 10 additions & 13 deletions
@@ -900,11 +900,11 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &

         cells.pos_set(idx, ubatch.pos[i]);

-        if (ubatch.has_mrope()) {
-            cells.pos_mrope_set(idx, {
-                ubatch.pos[i + ubatch.n_tokens],   // y
-                ubatch.pos[i + ubatch.n_tokens*2], // x
-            });
+        if (ubatch.is_pos_2d()) {
+            llama_kv_cell_ext ext;
+            ext.x = ubatch.pos[i + ubatch.n_tokens*2];
+            ext.y = ubatch.pos[i + ubatch.n_tokens];
+            cells.ext_set(idx, std::move(ext));
         }

         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
@@ -1251,11 +1251,8 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                 const llama_pos p1 = ubatch->pos[i];

                 // for M-RoPE
-                llama_kv_pos_mrope p1_mrope;
-                if (ubatch->has_mrope()) {
-                    p1_mrope.y = ubatch->pos[i + ubatch->n_tokens];
-                    p1_mrope.x = ubatch->pos[i + ubatch->n_tokens*2];
-                }
+                llama_pos p1_x = ubatch->pos[i + ubatch->n_tokens*2];
+                llama_pos p1_y = ubatch->pos[i + ubatch->n_tokens];

                 const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

@@ -1277,9 +1274,9 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                     }

                     // M-RoPE causal mask
-                    if (causal_attn && ubatch->has_mrope() && p0 == p1) {
-                        const auto & p0_mrope = cells.pos_mrope_get(j);
-                        if (p0_mrope.is_gt(p1_mrope)) {
+                    if (causal_attn && ubatch->is_pos_2d() && p0 == p1) {
+                        const auto & p0_ext = cells.ext_get(j);
+                        if (p0_ext.is_2d_gt(p1_x, p1_y)) {
                             continue;
                         }
                     }
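The masking change works in two steps: the ordinary causal test on the sequential position runs first, and only when two tokens tie (p0 == p1, as happens for the patches of a single image) does the 2D comparison break the tie. A condensed, standalone sketch of that rule; the real loop also checks sequence membership, SWA, and so on:

#include <cstdint>

using llama_pos = int32_t; // as in llama.h

// Condensed sketch of the rule above, not the verbatim loop body.
static bool masked_2d(llama_pos p0,   llama_pos p1,    // sequential positions
                      llama_pos p0_x, llama_pos p0_y,  // key (cached cell ext)
                      llama_pos p1_x, llama_pos p1_y)  // query (current token)
{
    if (p0 > p1) {
        return true; // ordinary causal mask on the sequential position
    }
    if (p0 == p1) {
        // tie-break in raster order: same test as llama_kv_cell_ext::is_2d_gt
        return (p0_y > p1_y) || (p0_y == p1_y && p0_x > p1_x);
    }
    return false;
}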

src/llama-kv-cells.h

Lines changed: 21 additions & 19 deletions
@@ -9,12 +9,14 @@
 #include <set>
 #include <map>

-struct llama_kv_pos_mrope {
-    llama_pos y = 0;
+struct llama_kv_cell_ext {
+    // 2D spatial positions, typically used for M-RoPE
     llama_pos x = 0;
-    // return true if this position is greater than the other position
-    bool is_gt(const llama_kv_pos_mrope & other) const {
-        return (y > other.y) || (y == other.y && x > other.x);
+    llama_pos y = 0;
+
+    // return true if the current 2D spatial position is greater than other
+    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
+        return (y > oy) || (y == oy && x > ox);
     }
 };

@@ -52,7 +54,7 @@ class llama_kv_cells {

     void resize(uint32_t n) {
         pos.resize(n);
-        pos_mrope.resize(n);
+        ext.resize(n);
         shift.resize(n);
         seq.resize(n);

@@ -117,9 +119,9 @@ class llama_kv_cells {
         for (uint32_t j = 0; j < n; ++j) {
             const auto idx = i + j;

-            res.pos      [j] = pos[idx];
-            res.pos_mrope[j] = pos_mrope[idx];
-            res.seq      [j] = seq[idx];
+            res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
+            res.seq[j] = seq[idx];

             assert(shift[idx] == 0);
         }
@@ -136,9 +138,9 @@ class llama_kv_cells {
         for (uint32_t j = 0; j < idxs.size(); ++j) {
             const auto idx = idxs[j];

-            res.pos      [j] = pos[idx];
-            res.pos_mrope[j] = pos_mrope[idx];
-            res.seq      [j] = seq[idx];
+            res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
+            res.seq[j] = seq[idx];

             assert(shift[idx] == 0);
         }
@@ -352,11 +354,11 @@ class llama_kv_cells {
         return pos[i];
     }

-    const llama_kv_pos_mrope & pos_mrope_get(uint32_t i) const {
+    const llama_kv_cell_ext & ext_get(uint32_t i) const {
         assert(i < pos.size());
         assert(pos[i] != -1);

-        return pos_mrope[i];
+        return ext[i];
     }

     // note: call only if the cell is not empty
@@ -387,9 +389,9 @@ class llama_kv_cells {
         used.insert(i);
     }

-    void pos_mrope_set(uint32_t i, llama_kv_pos_mrope p) {
-        assert(i < pos_mrope.size());
-        pos_mrope[i] = p;
+    void ext_set(uint32_t i, llama_kv_cell_ext && p) {
+        assert(i < ext.size());
+        ext[i] = std::move(p);
     }

     // pos[i] = pos[i] + d
@@ -448,8 +450,8 @@ class llama_kv_cells {

     std::vector<llama_pos> pos;

-    // stores addition info for M-RoPE positions
-    std::vector<llama_kv_pos_mrope> pos_mrope;
+    // stores extra info per cell
+    std::vector<llama_kv_cell_ext> ext;

     // this array accumulates any applied shifts to the pos array since the last reset_shift() call
     // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
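As a quick sanity check of the renamed comparison, is_2d_gt orders positions row-major: the y coordinate dominates and x breaks ties within a row. A standalone example (not part of the commit, assumes llama-kv-cells.h is included):

#include <cassert>

// assumes "llama-kv-cells.h" is included for llama_kv_cell_ext

int main() {
    llama_kv_cell_ext a;
    a.x = 5;
    a.y = 2;

    assert( a.is_2d_gt(4, 2)); // same row, a is further right
    assert( a.is_2d_gt(9, 1)); // a is on a later row, column is irrelevant
    assert(!a.is_2d_gt(5, 2)); // equal position is not "greater"
    assert(!a.is_2d_gt(0, 3)); // earlier row, even though a.x is larger

    return 0;
}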
