Merged

Changes from 2 commits
47 changes: 20 additions & 27 deletions src/llama-batch.cpp
@@ -251,46 +251,39 @@ bool llama_batch_allocr::init(
// consistency checks
//

-    for (uint32_t s = 0; s < n_seq_max; ++s) {
-        if (seq_pos[s].empty()) {
-            continue;
-        }
+    // TODO @ngxson : we currently can't check M-RoPE positions, as the position is increased based on image size
+    if (n_pos_per_embd == 1) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }

-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

-        if (p0 >= 0) {
-            bool ok = true;
+            if (p0 >= 0) {
+                bool ok = true;

-            if (batch.token) {
                 if (seq_pos_min(s) != p0 + 1) {
                     ok = false;
                 }
-            } else {
-                assert(batch.embd);
-
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
-                }
-            }

-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                    "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                    " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                    " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                    " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                    __func__, s, s, p0, s, seq_pos_min(s));
-
-                return false;
-            }
-        }
+                if (!ok) {
+                    LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
+                }
+            }

-        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
-            return false;
-        }
+            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+                return false;
+            }
+        }
     }

if (memory) {
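For intuition, here is a small standalone sketch (hypothetical values, not PR code) of why the Y = X + 1 check above cannot hold once an image chunk is involved: with this PR, an image advances the temporal position by max(nx, ny) rather than by its token count, so the text batch that follows it does not start at the cache's last position + 1.

// Standalone sketch, hypothetical values (not part of the PR).
#include <algorithm>
#include <cassert>

int main() {
    // a plain text batch: the cache ends at position 9, the batch starts at 10 -> Y = X + 1 holds
    int p0 = 9;
    int batch_first = 10;
    assert(batch_first == p0 + 1);

    // an image chunk: its patches all share temporal position 10, and (per the mtmd.cpp
    // change in this PR) the chunk advances the sequence by max(nx, ny) positions
    const int nx = 4, ny = 3;          // hypothetical patch grid
    p0 = 10;                           // the cache now ends at the image's temporal position
    batch_first = 10 + std::max(nx, ny);

    // the next text batch starts at 14, not 11, so the consecutive-position check
    // would reject it -> hence the new n_pos_per_embd == 1 guard
    assert(batch_first != p0 + 1);
}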
4 changes: 4 additions & 0 deletions src/llama-batch.h
@@ -17,6 +17,10 @@ struct llama_ubatch {
return b_equal_seqs != 0;
}

bool has_mrope() const {
return data->pos.size() == data->token.size()*4;
}

Member:

I think we can make this multi-dimensional positional information more decoupled from the concept of rope:

diff --git a/src/llama-batch.h b/src/llama-batch.h
index 34f964ef0..8a6c6daff 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -17,8 +17,13 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
-    bool has_mrope() const {
-        return data->pos.size() == data->token.size()*4;
+    // typical for M-RoPE cases:
+    //   0 - sequential position of the tokens/embeddings in the sequence
+    //   1 - x position in the image
+    //   2 - y position in the image
+    //   3 - other
+    bool is_pos_2d() const {
+        return n_pos >= 3;
     }
 
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
@@ -29,6 +34,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // position inputs for each token/embedding
 
     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -37,7 +43,7 @@ struct llama_ubatch {
     //                          // size               | idx | val
     llama_token  *  token;      // [n_tokens]         | i   | id, token
     float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    llama_pos    *  pos;        // [n_tokens*n_pos]   | i   | pos
     int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
     llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id

uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
// otherwise address sanitizer complains
// TODO: whole_seqs for embeddings?
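For reference, a rough sketch of the sectioned pos layout that has_mrope() tests for above, and that the pos[i + n_tokens*k] indexing in llama-kv-cache.cpp below relies on. The layout (one contiguous section of n_tokens entries per position dimension) is taken from the PR; the concrete values are hypothetical.

// Illustrative sketch of the sectioned M-RoPE pos layout (hypothetical values, not PR code).
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t;

int main() {
    const uint32_t n_tokens = 4;                  // e.g. a 2x2-patch image

    // an M-RoPE ubatch carries 4 position entries per token, section by section,
    // which is exactly what has_mrope() detects: pos.size() == token.size()*4
    const std::vector<llama_pos> pos = {
        10, 10, 10, 10,   // section 0: temporal position p
         0,  0,  1,  1,   // section 1: spatial y (read as pos[i + n_tokens])
         0,  1,  0,  1,   // section 2: spatial x (read as pos[i + n_tokens*2])
         0,  0,  0,  0,   // section 3: unused for the 2D image case
    };

    for (uint32_t i = 0; i < n_tokens; ++i) {
        std::printf("token %u: p=%d y=%d x=%d\n",
                    i, pos[i], pos[i + n_tokens], pos[i + n_tokens*2]);
    }
}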
22 changes: 22 additions & 0 deletions src/llama-kv-cache.cpp
@@ -900,6 +900,13 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &

cells.pos_set(idx, ubatch.pos[i]);

if (ubatch.has_mrope()) {
cells.pos_mrope_set(idx, {
ubatch.pos[i + ubatch.n_tokens], // y
ubatch.pos[i + ubatch.n_tokens*2], // x
});
}

for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(idx, ubatch.seq_id[i][s]);
}
@@ -1243,6 +1250,13 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u

const llama_pos p1 = ubatch->pos[i];

// for M-RoPE
llama_kv_pos_mrope p1_mrope;
if (ubatch->has_mrope()) {
p1_mrope.y = ubatch->pos[i + ubatch->n_tokens];
p1_mrope.x = ubatch->pos[i + ubatch->n_tokens*2];
}

Member:

I find it confusing to have the order of the positions as y, x. It's more canonical to have the dimensions ordered by increasing significance - x, y, z, .... This is also in line with the ggml convention for indexing.

I now notice that even the implementation of ggml_rope_multi uses this order. I would recommend updating this across the codebase for consistency. Even though it's a breaking change, it's better to do it now, before the mtmd stuff gets more adopted.

Collaborator Author:

yes I agree that we should fix the ordering in ggml, I will make a PR for that

Collaborator Author (@ngxson, Oct 29, 2025):

Hmm, on second thought, I think it cannot be ordered as x,y,z, because the full 4D position would be p,x,y,z, with p being the traditional LLM position.

Since Qwen doesn't use the last z dim, the ordering is currently p,y,x, which is in decreasing significance.

I think the better way is, as you suggest above, to decouple the logic into something like 2d_mrope to be more specific.

Member:

Hm, not sure I follow. My point is that p,x,y,t is a more consistent order than the current p,y,x,t.

const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1262,6 +1276,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
continue;
}

// M-RoPE causal mask
if (causal_attn && ubatch->has_mrope() && p0 == p1) {
const auto & p0_mrope = cells.pos_mrope_get(j);
if (p0_mrope.is_gt(p1_mrope)) {
continue;
}
}

// apply SWA if any
if (is_masked_swa(p0, p1)) {
continue;
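To make the new masking rule concrete, here is a minimal standalone sketch (illustrative only, hypothetical positions) that combines the regular causal check on the temporal position with the 2D tie-breaker above, using a local copy of the is_gt comparison defined in llama-kv-cells.h below.

// Standalone sketch of the M-RoPE causal rule (hypothetical positions, not PR code).
#include <cassert>

struct pos_2d {
    int y = 0;
    int x = 0;
    // same comparison as llama_kv_pos_mrope::is_gt below
    bool is_gt(const pos_2d & other) const {
        return (y > other.y) || (y == other.y && x > other.x);
    }
};

// true if the cached token at (p0, p0_2d) must be masked for the query token at (p1, p1_2d)
static bool masked_causal(int p0, pos_2d p0_2d, int p1, pos_2d p1_2d) {
    if (p0 > p1) {
        return true;   // regular causal mask on the temporal position
    }
    if (p0 == p1 && p0_2d.is_gt(p1_2d)) {
        return true;   // same temporal position (e.g. same image): break the tie on (y, x)
    }
    return false;
}

int main() {
    assert( masked_causal(7, {2, 0}, 7, {1, 3}));  // later row of the same image -> masked
    assert(!masked_causal(7, {1, 3}, 7, {2, 0}));  // earlier row -> visible
    assert(!masked_causal(6, {5, 5}, 7, {0, 0}));  // strictly earlier temporal position -> visible
}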
35 changes: 31 additions & 4 deletions src/llama-kv-cells.h
@@ -9,6 +9,15 @@
#include <set>
#include <map>

struct llama_kv_pos_mrope {
llama_pos y = 0;
llama_pos x = 0;
// return true if this position is greater than the other position
bool is_gt(const llama_kv_pos_mrope & other) const {
return (y > other.y) || (y == other.y && x > other.x);
}
};

Member:

Similarly, I think we can decouple the concept of M-RoPE here by declaring this struct to be more generic:

struct llama_kv_cell_ext {
    // 2D spatial positions, typically used for M-RoPE
    llama_pos x = 0;
    llama_pos y = 0;

    // ... maybe more data in the future
};

// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells {
@@ -43,6 +52,7 @@ class llama_kv_cells {

void resize(uint32_t n) {
pos.resize(n);
pos_mrope.resize(n);
shift.resize(n);
seq.resize(n);

@@ -107,8 +117,9 @@ class llama_kv_cells {
for (uint32_t j = 0; j < n; ++j) {
const auto idx = i + j;

-res.pos[j] = pos[idx];
-res.seq[j] = seq[idx];
+res.pos [j] = pos[idx];
+res.pos_mrope[j] = pos_mrope[idx];
+res.seq [j] = seq[idx];

assert(shift[idx] == 0);
}
@@ -125,8 +136,9 @@ class llama_kv_cells {
for (uint32_t j = 0; j < idxs.size(); ++j) {
const auto idx = idxs[j];

-res.pos[j] = pos[idx];
-res.seq[j] = seq[idx];
+res.pos [j] = pos[idx];
+res.pos_mrope[j] = pos_mrope[idx];
+res.seq [j] = seq[idx];

assert(shift[idx] == 0);
}
@@ -340,6 +352,13 @@ class llama_kv_cells {
return pos[i];
}

const llama_kv_pos_mrope & pos_mrope_get(uint32_t i) const {
assert(i < pos.size());
assert(pos[i] != -1);

return pos_mrope[i];
}

// note: call only if the cell is not empty
llama_pos get_shift(uint32_t i) const {
assert(i < pos.size());
@@ -368,6 +387,11 @@ class llama_kv_cells {
used.insert(i);
}

void pos_mrope_set(uint32_t i, llama_kv_pos_mrope p) {
assert(i < pos_mrope.size());
pos_mrope[i] = p;
}

// pos[i] = pos[i] + d
// sets "has_shift" to true
// note: call only if the cell is not empty
@@ -424,6 +448,9 @@ class llama_kv_cells {

std::vector<llama_pos> pos;

// stores addition info for M-RoPE positions
std::vector<llama_kv_pos_mrope> pos_mrope;

Member:

Suggested change:
-    // stores addition info for M-RoPE positions
-    std::vector<llama_kv_pos_mrope> pos_mrope;
+    // stores extra optional cell info
+    std::vector<llama_kv_cell_ext> ext;

// this array accumulates any applied shifts to the pos array since the last reset_shift() call
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
//
13 changes: 12 additions & 1 deletion tools/mtmd/mtmd.cpp
@@ -5,6 +5,15 @@

#include "llama.h"

// fix problem with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif

#include <algorithm>
#include <cerrno>
#include <cstdio>
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {

llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+        // for M-RoPE, temporal dimension = max(t,h,w)
+        // t is omitted as we don't support video input
+        return std::max(image_tokens->nx, image_tokens->ny);
}
return image_tokens->n_tokens();
}
4 changes: 2 additions & 2 deletions tools/mtmd/mtmd.h
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);

// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate

// tokenize an input text prompt and a list of bitmaps (images/audio)