
Commit ef1af68

Merge pull request #31 from rujialiu/mrope-fix

Mrope fix

2 parents: 1c1409e + 7db161a

5 files changed: +27 / -33 lines

include/llama.h
Lines changed: 1 addition & 1 deletion

@@ -232,7 +232,7 @@ extern "C" {
 
         llama_token  *  token;
         float        *  embd;
-        llama_pos    *  pos;
+        llama_pos    *  pos;      // first `n_tokens` elements are always linearly increasing position for traditional llm
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits;   // TODO: rename this to "output"
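For orientation, the following is a minimal standalone sketch, not part of this commit, of the `pos` layout that the new comment describes for an M-RoPE image batch: the first `n_tokens` entries stay linearly increasing for causal masking, and the `n_pos_per_embd` model-specific sections follow. The function name `pack_mrope_pos` is illustrative; the real packing lives in `decode_embd_batch` in tools/mtmd/mtmd-helper.cpp.

// Sketch only: pack positions for an nx*ny image patch grid starting at pos_0,
// assuming qwen2vl-style M-RoPE (n_pos_per_embd = 4).
// Layout of the returned buffer:
//   [0,            n_tokens)     : linear positions, used for causal masking
//   [n_tokens,   2*n_tokens)     : temporal section (constant pos_0)
//   [2*n_tokens, 3*n_tokens)     : height section   (pos_0 + y)
//   [3*n_tokens, 4*n_tokens)     : width section    (pos_0 + x)
//   [4*n_tokens, 5*n_tokens)     : unused section   (0)
#include <cstdint>
#include <vector>

using llama_pos = int32_t; // local alias for the sketch

std::vector<llama_pos> pack_mrope_pos(int nx, int ny, llama_pos pos_0) {
    const int n_tokens       = nx * ny;
    const int n_pos_per_embd = 4;
    std::vector<llama_pos> pos((size_t) n_tokens * (n_pos_per_embd + 1));
    for (int y = 0; y < ny; y++) {
        for (int x = 0; x < nx; x++) {
            const int i = y * nx + x;
            pos[i                 ] = pos_0 + i; // linear position for masking
            pos[i + n_tokens      ] = pos_0;     // temporal: whole image is one step
            pos[i + n_tokens * 2  ] = pos_0 + y; // height
            pos[i + n_tokens * 3  ] = pos_0 + x; // width
            pos[i + n_tokens * 4  ] = 0;         // last pos dim is unused
        }
    }
    return pos;
}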

src/llama-batch.cpp
Lines changed: 3 additions & 19 deletions

@@ -259,23 +259,7 @@ bool llama_batch_allocr::init(
             const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
 
             if (p0 >= 0) {
-                bool ok = true;
-
-                if (batch.token) {
-                    if (seq_pos_min(s) != p0 + 1) {
-                        ok = false;
-                    }
-                } else {
-                    assert(batch.embd);
-
-                    // for embeddings (typically used as vision input), we allow them to have repeating positions
-                    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                    if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                        ok = false;
-                    }
-                }
-
-                if (!ok) {
+                if (seq_pos_min(s) != p0 + 1) {
                     LLAMA_LOG_ERROR(
                         "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                         " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"

@@ -655,7 +639,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
+    const int32_t n_pos_cur = batch.embd ? (n_pos_per_embd + 1) : 1;
 
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  = (int64_t) n_tokens*n_pos_cur;

@@ -681,7 +665,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
 
         for (int j = 0; j < n_pos_cur; ++j) {
-            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+            udata->pos[j * n_tokens + i] = batch.pos[j * batch.n_tokens + idxs[i]];
         }
 
         udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
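As a hedged illustration of the indexing in the last hunk (a sketch under the assumption that `pos` is laid out section by section, not the library's actual helper), the gather performed for the selected token indices `idxs` can be written as a standalone function:

// Sketch only: copy the position sections for the chosen tokens out of a
// source batch. An embedding batch carries the linear section plus
// n_pos_per_embd model-specific sections, hence n_pos_cur = n_pos_per_embd + 1.
#include <cstdint>
#include <vector>

using llama_pos = int32_t; // local alias for the sketch

std::vector<llama_pos> gather_pos(const std::vector<llama_pos> & batch_pos,
                                  int32_t batch_n_tokens,
                                  const std::vector<int32_t> & idxs,
                                  bool is_embd, int32_t n_pos_per_embd) {
    const int32_t n_tokens  = (int32_t) idxs.size();
    const int32_t n_pos_cur = is_embd ? (n_pos_per_embd + 1) : 1;

    std::vector<llama_pos> out((size_t) n_tokens * n_pos_cur);
    for (int32_t i = 0; i < n_tokens; ++i) {
        for (int32_t j = 0; j < n_pos_cur; ++j) {
            // section j of the source maps to section j of the destination
            out[j * n_tokens + i] = batch_pos[j * batch_n_tokens + idxs[i]];
        }
    }
    return out;
}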

src/llama-graph.cpp
Lines changed: 7 additions & 1 deletion

@@ -54,7 +54,13 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
             }
             ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
         } else {
-            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
+            llama_pos * pos_ptr = ubatch->pos;
+            // Normally, ubatch->pos stores linearly increasing position
+            // However, some multi-modal models requires special position embedding (e.g. M-Rope in qwen2vl and qwen2.5vl)
+            // But linearly increasing position is still needed for proper causal attention masking
+            // So we store both of them: the first n_tokens elements are not changed, while model-specific positions are appended after that.
+            if (ubatch->embd && n_pos_per_embd > 1) pos_ptr += n_tokens; // use mrope positions
+            ggml_backend_tensor_set(pos, pos_ptr, 0, n_tokens * n_pos_per_embd * ggml_element_size(pos));
         }
     }
 }
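The pointer arithmetic above reduces to a small selection rule. Here is a minimal sketch of it (assumed names, no ggml calls): the graph input consumes n_tokens * n_pos_per_embd values, and for an M-RoPE embedding batch those sit after the leading linear section.

// Sketch only: pick the start of the position data to upload to the graph input.
#include <cstdint>

using llama_pos = int32_t; // local alias for the sketch

const llama_pos * select_pos_for_graph(const llama_pos * pos,   // e.g. ubatch->pos
                                       int32_t n_tokens,
                                       int32_t n_pos_per_embd,
                                       bool    is_embd) {
    // text batches (or models without M-RoPE): the linear section is consumed directly
    if (!is_embd || n_pos_per_embd <= 1) {
        return pos;
    }
    // M-RoPE embedding batch: skip the linear section kept for causal masking
    // and hand the graph the n_pos_per_embd model-specific sections
    return pos + n_tokens;
}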

tools/mtmd/mtmd-helper.cpp
Lines changed: 16 additions & 9 deletions

@@ -55,6 +55,11 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
 
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
+// notes2: Normally, batch's `pos` stores linearly increasing position
+// However, some multi-modal models requires special position embedding (e.g. M-Rope in qwen2vl and qwen2.5vl)
+// But linearly increasing position is still needed for proper causal attention masking
+// So we store both of them: the first n_tokens elements are not changed, while model-specific positions are appended after that.
+// So `pos` has `n_tokens * (n_pos_per_embd + 1)` elements
 struct decode_embd_batch {
     int n_pos_per_embd;
     int n_mmproj_embd;

@@ -66,7 +71,7 @@ struct decode_embd_batch {
     std::vector<int8_t> logits;
     llama_batch batch;
     decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
-        pos     .resize(n_tokens * n_pos_per_embd);
+        pos     .resize(n_tokens * (n_pos_per_embd + 1));
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
         logits  .resize(n_tokens);

@@ -100,13 +105,14 @@ struct decode_embd_batch {
         for (int y = 0; y < ny; y++) {
             for (int x = 0; x < nx; x++) {
                 int i = y * nx + x;
-                pos[i                     ] = pos_0;
-                pos[i + batch.n_tokens    ] = pos_0 + y;
-                pos[i + batch.n_tokens * 2] = pos_0 + x;
-                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+                pos[i + batch.n_tokens    ] = pos_0;
+                pos[i + batch.n_tokens * 2] = pos_0 + y;
+                pos[i + batch.n_tokens * 3] = pos_0 + x;
+                pos[i + batch.n_tokens * 4] = 0; // last pos dim is unused
             }
         }
         for (int i = 0; i < batch.n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
             batch.logits  [i] = false;

@@ -118,12 +124,13 @@ struct decode_embd_batch {
         GGML_ASSERT(n_pos_per_embd == 4);
         seq_id_0[0] = seq_id;
         for (int i = 0; i < batch.n_tokens; i++) {
-            pos[i                     ] = pos_0 + i;
             pos[i + batch.n_tokens    ] = pos_0 + i;
             pos[i + batch.n_tokens * 2] = pos_0 + i;
-            pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+            pos[i + batch.n_tokens * 3] = pos_0 + i;
+            pos[i + batch.n_tokens * 4] = 0; // last pos dim is unused
         }
         for (int i = 0; i < batch.n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
             batch.logits  [i] = false;

@@ -133,12 +140,12 @@ struct decode_embd_batch {
     llama_batch get_view(int offset, int n_tokens) {
         llama_pos * pos_ptr;
         pos_view.clear();
-        pos_view.reserve(n_tokens * n_pos_per_embd);
+        pos_view.reserve(n_tokens * (n_pos_per_embd + 1));
         if (n_pos_per_embd > 1) {
             // mrope
             // for example, with layout of src: 1234...1234...1234...1234...
             // offset 2 will give us dst: 34...34...34...34...
-            for (int i = 0; i < n_pos_per_embd; i++) {
+            for (int i = 0; i <= n_pos_per_embd; i++) {
                 // assume n_tokens is less than or equal to batch.n_tokens
                 // batch.n_tokens is number of **total** tokens
                 // n_tokens is number of viewed token
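The last hunk widens the view slicing to cover the extra leading linear section, which is why the loop now runs `i <= n_pos_per_embd`. A self-contained sketch of that slicing (illustrative only; the real code builds `pos_view` inside `get_view`), using the "1234...1234..." layout from the comment:

// Sketch only: slice n_tokens positions starting at `offset` out of every
// section of a section-major pos buffer of (n_pos_per_embd + 1) sections,
// each total_n_tokens long (the first section holds the linear positions).
#include <cstdint>
#include <vector>

using llama_pos = int32_t; // local alias for the sketch

std::vector<llama_pos> slice_pos_view(const std::vector<llama_pos> & pos,
                                      int32_t total_n_tokens,
                                      int32_t offset,
                                      int32_t n_tokens,       // offset + n_tokens <= total_n_tokens
                                      int32_t n_pos_per_embd) {
    std::vector<llama_pos> view;
    view.reserve((size_t) n_tokens * (n_pos_per_embd + 1));
    for (int32_t i = 0; i <= n_pos_per_embd; i++) {   // every section, linear one included
        const llama_pos * src = pos.data() + (size_t) i * total_n_tokens + offset;
        view.insert(view.end(), src, src + n_tokens); // e.g. src 1234... with offset 2 gives 34...
    }
    return view;
}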

tools/mtmd/mtmd.cpp
Lines changed: 0 additions & 3 deletions

@@ -1030,9 +1030,6 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
 }
 
 llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
-    if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
-    }
     return image_tokens->n_tokens();
 }
 
