
Commit 90353ea

correct x,y ordering

1 parent bf7f924

File tree: 5 files changed, +42 -51 lines

src/llama-batch.cpp

Lines changed: 20 additions & 30 deletions
@@ -251,46 +251,39 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (uint32_t s = 0; s < n_seq_max; ++s) {
-        if (seq_pos[s].empty()) {
-            continue;
-        }
+    // TODO @ngxson : we currently can't check M-RoPE positions, as the position is increased based on image size
+    if (n_pos_per_embd == 1) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
 
-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
 
-        if (p0 >= 0) {
-            bool ok = true;
+            if (p0 >= 0) {
+                bool ok = true;
 
-            if (batch.token) {
                 if (seq_pos_min(s) != p0 + 1) {
                     ok = false;
                 }
-            } else {
-                assert(batch.embd);
 
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
+                if (!ok) {
+                    LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
                 }
             }
 
-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                    "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                    " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                    " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                    " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                    __func__, s, s, p0, s, seq_pos_min(s));
-
+            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
                 return false;
             }
         }
-
-        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
-            return false;
-        }
     }
 
     if (memory) {
@@ -660,9 +653,6 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
 
-    // printf("ubatch_add: n_tokens=%d, n_seqs=%d, n_pos_cur=%d, n_embd_all=%lld, n_pos_all=%lld\n",
-    //     n_tokens, n_seqs, n_pos_cur, n_embd_all, n_pos_all);
-
     udata->token   .resize(n_tokens);
     udata->embd    .resize(n_embd_all);
     udata->pos     .resize(n_pos_all);
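Note: the reworked consistency check in init() keeps the same invariant for token batches: a sequence's positions in the incoming batch must continue exactly where the KV cache left off (Y = X + 1) and must be continuous. A minimal standalone sketch of that invariant, with illustrative names (kv_pos_max, batch_pos) standing in for memory->seq_pos_max(s) and the per-sequence positions of the batch:

#include <algorithm>
#include <cstdio>
#include <vector>

// Sketch of the invariant checked in llama_batch_allocr::init().
static bool positions_consistent(int kv_pos_max, const std::vector<int> & batch_pos) {
    if (batch_pos.empty()) {
        return true;
    }
    const int p_min = *std::min_element(batch_pos.begin(), batch_pos.end());
    const int p_max = *std::max_element(batch_pos.begin(), batch_pos.end());
    // Y must equal X + 1 ...
    if (kv_pos_max >= 0 && p_min != kv_pos_max + 1) {
        return false;
    }
    // ... and the span of positions must have no gaps.
    return p_max - p_min + 1 <= (int) batch_pos.size();
}

int main() {
    std::printf("%d\n", positions_consistent(4, {5, 6, 7})); // 1: continues at X + 1
    std::printf("%d\n", positions_consistent(4, {6, 7}));    // 0: gap after the cache
    return 0;
}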

src/llama-kv-cache.cpp

Lines changed: 4 additions & 6 deletions
@@ -902,9 +902,8 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         if (ubatch.has_mrope()) {
             cells.pos_mrope_set(idx, {
-                ubatch.pos[i + ubatch.n_tokens],   // x
-                ubatch.pos[i + ubatch.n_tokens*2], // y
-                ubatch.pos[i + ubatch.n_tokens*3], // t
+                ubatch.pos[i + ubatch.n_tokens],   // y
+                ubatch.pos[i + ubatch.n_tokens*2], // x
             });
         }
 
@@ -1254,9 +1253,8 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                 // for M-RoPE
                 llama_kv_pos_mrope p1_mrope;
                 if (ubatch->has_mrope()) {
-                    p1_mrope.x = ubatch->pos[i + ubatch->n_tokens];
-                    p1_mrope.y = ubatch->pos[i + ubatch->n_tokens*2];
-                    p1_mrope.t = ubatch->pos[i + ubatch->n_tokens*3];
+                    p1_mrope.y = ubatch->pos[i + ubatch->n_tokens];
+                    p1_mrope.x = ubatch->pos[i + ubatch->n_tokens*2];
                 }
 
                 const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
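Note: both hunks read ubatch positions from a planar layout: component k of token i lives at pos[i + n_tokens*k], and this commit swaps which planes are labeled y and x. A hedged sketch of the indexing; the plane meanings follow the comments in the hunks above, and the example values are made up:

#include <cstdio>
#include <vector>

using llama_pos = int; // stand-in for the real typedef

int main() {
    const int n_tokens = 3;
    // Planar layout: plane 0 holds the primary position of each token,
    // plane 1 holds y, plane 2 holds x (post-commit ordering).
    std::vector<llama_pos> pos = {
        10, 11, 12, // plane 0: primary position per token
         0,  0,  1, // plane 1: y
         0,  1,  0, // plane 2: x
    };
    for (int i = 0; i < n_tokens; ++i) {
        std::printf("token %d: pos=%d y=%d x=%d\n",
                    i, pos[i], pos[i + n_tokens], pos[i + n_tokens*2]);
    }
    return 0;
}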

src/llama-kv-cells.h

Lines changed: 4 additions & 12 deletions
@@ -10,14 +10,11 @@
 #include <map>
 
 struct llama_kv_pos_mrope {
-    llama_pos x;
-    llama_pos y;
-    llama_pos t;
+    llama_pos y = 0;
+    llama_pos x = 0;
     // return true if this position is greater than the other position
     bool is_gt(const llama_kv_pos_mrope & other) const {
-        return (t > other.t)
-            || (t == other.t && y > other.y)
-            || (t == other.t && y == other.y && x > other.x);
+        return (y > other.y) || (y == other.y && x > other.x);
     }
 };
 
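Note: with the temporal component dropped from the struct, is_gt reduces to a lexicographic greater-than on (y, x). A small self-contained check of that equivalence using std::tie; the struct below mirrors the one in the hunk:

#include <cassert>
#include <tuple>

using llama_pos = int; // stand-in for the real typedef

struct llama_kv_pos_mrope {
    llama_pos y = 0;
    llama_pos x = 0;
    bool is_gt(const llama_kv_pos_mrope & other) const {
        return (y > other.y) || (y == other.y && x > other.x);
    }
};

int main() {
    const llama_kv_pos_mrope a = {2, 0};
    const llama_kv_pos_mrope b = {1, 9};
    // is_gt behaves like lexicographic > on the tuple (y, x):
    assert(a.is_gt(b) == (std::tie(a.y, a.x) > std::tie(b.y, b.x)));
    assert( a.is_gt(b)); // y dominates: (2,0) > (1,9)
    assert(!b.is_gt(a));
    return 0;
}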

@@ -391,13 +388,8 @@ class llama_kv_cells {
     }
 
     void pos_mrope_set(uint32_t i, llama_kv_pos_mrope p) {
-        assert(i < pos.size());
-        assert(pos[i] == -1);
-        assert(seq[i].none());
-
+        assert(i < pos_mrope.size());
         pos_mrope[i] = p;
-
-        used.insert(i);
     }
 
     // pos[i] = pos[i] + d

tools/mtmd/mtmd.cpp

Lines changed: 12 additions & 1 deletion
@@ -5,6 +5,15 @@
 
 #include "llama.h"
 
+// fix problem with std::min and std::max
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#   define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
 #include <algorithm>
 #include <cerrno>
 #include <cstdio>
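Note: without NOMINMAX, <windows.h> defines function-style min/max macros that break calls like the std::max added in the next hunk. A hedged sketch of the usual portable alternative, parenthesizing the call name so a macro cannot expand:

#include <algorithm>
#include <cstdio>

int main() {
    int nx = 4, ny = 3;
    // (std::max)(...) defeats function-style macro expansion even when a
    // `max` macro is in scope, so this compiles with or without NOMINMAX.
    int n_pos = (std::max)(nx, ny);
    std::printf("n_pos = %d\n", n_pos); // prints: n_pos = 4
    return 0;
}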
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
 
 llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
     if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+        // for M-RoPE, temporal dimension = max(t,h,w)
+        // t is omitted as we don't support video input
+        return std::max(image_tokens->nx, image_tokens->ny);
     }
     return image_tokens->n_tokens();
 }
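Note: an image chunk now advances the sequence position by max(nx, ny) instead of 1. A short arithmetic sketch; the 8x6 grid is illustrative, not from this commit:

#include <algorithm>
#include <cstdio>

int main() {
    // An image tokenized into an 8x6 grid used to advance the position by 1;
    // it now advances it by max(nx, ny) = 8, mirroring
    // mtmd_image_tokens_get_n_pos() above.
    const int nx = 8, ny = 6;   // illustrative grid size
    int n_past = 100;           // position before the image chunk
    n_past += std::max(nx, ny); // position after: 108
    std::printf("n_past after image chunk: %d\n", n_past);
    return 0;
}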

tools/mtmd/mtmd.h

Lines changed: 2 additions & 2 deletions
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
 MTMD_API size_t       mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk);
 // returns nullptr for ID on text chunk
 MTMD_API const char * mtmd_input_chunk_get_id      (const mtmd_input_chunk * chunk);
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals max(t,h,w) for M-RoPE; equals n_tokens otherwise)
 MTMD_API llama_pos    mtmd_input_chunk_get_n_pos   (const mtmd_input_chunk * chunk);
 
 // in case you want to use custom logic to handle the chunk (i.e. KV cache management)
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
 MTMD_API size_t       mtmd_image_tokens_get_nx   (const mtmd_image_tokens * image_tokens);
 MTMD_API size_t       mtmd_image_tokens_get_ny   (const mtmd_image_tokens * image_tokens);
 MTMD_API const char * mtmd_image_tokens_get_id   (const mtmd_image_tokens * image_tokens); // TODO: deprecate
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals max(t,h,w) for M-RoPE; equals n_tokens otherwise)
 MTMD_API llama_pos    mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens); // TODO: deprecate
 
 // tokenize an input text prompt and a list of bitmaps (images/audio)
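Note: callers that track n_past should advance it by mtmd_input_chunk_get_n_pos() rather than by the chunk's token count, since the two now differ for M-RoPE image chunks. A hedged usage sketch; the chunk-iteration helpers (mtmd_input_chunks_size, mtmd_input_chunks_get) are assumed from the wider mtmd API and do not appear in this diff:

#include "mtmd.h"

// Advance the sequence position across all tokenized chunks.
static llama_pos advance_n_past(const mtmd_input_chunks * chunks, llama_pos n_past) {
    for (size_t i = 0; i < mtmd_input_chunks_size(chunks); ++i) {
        const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i);
        // For an M-RoPE image chunk this adds max(t,h,w), not n_tokens.
        n_past += mtmd_input_chunk_get_n_pos(chunk);
    }
    return n_past;
}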
