67 changes: 41 additions & 26 deletions src/llama-batch.cpp
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
/*.n_seq_tokens =*/ (uint32_t) 1,
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
/*.n_pos =*/ n_pos_per_embd,
/*.token =*/ batch.token,
/*.embd =*/ batch.embd,
/*.pos =*/ batch.pos,
@@ -251,46 +252,58 @@ bool llama_batch_allocr::init(
// consistency checks
//

for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
if (n_pos_per_embd > 1) {
// M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}

const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

if (p0 >= 0 && p0 >= seq_pos_min(s)) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" for M-RoPE, it is required that the position satisfies: X < Y\n",
__func__, s, s, p0, s, seq_pos_min(s));

return false;
}
}
} else {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}

const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

if (p0 >= 0) {
bool ok = true;
if (p0 >= 0) {
bool ok = true;

if (batch.token) {
if (seq_pos_min(s) != p0 + 1) {
ok = false;
}
} else {
assert(batch.embd);

// for embeddings (typically used as vision input), we allow them to have repeating positions
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
ok = false;
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));

return false;
}
}

if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));

if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
}
}

if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
}
}

if (memory) {
@@ -389,6 +402,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ n_seqs,
/*.n_pos =*/ n_pos_per_embd,

/*.token =*/ udata->token.data(),
/*.embd =*/ nullptr,
@@ -710,6 +724,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
/*.n_seq_tokens =*/ n_tokens/n_seqs,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
/*.n_pos =*/ n_pos_per_embd,

/*.token =*/ batch.token ? udata->token.data() : nullptr,
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
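For reference, a standalone sketch (not part of the diff) of the position constraints that the checks above enforce, where p0 is the last position stored in the KV cache for a sequence and y0 is the smallest position of that sequence in the incoming batch:

#include <cassert>

// Sketch only: mirrors the consistency checks in llama_batch_allocr::init() above.
static bool positions_ok(int p0, int y0, bool is_mrope, bool is_embd) {
    if (p0 < 0) {
        return true;                      // nothing stored yet, any start is fine
    }
    if (is_mrope) {
        return y0 > p0;                   // M-RoPE: positions may jump, but only forward
    }
    if (is_embd) {
        return y0 == p0 || y0 == p0 + 1;  // embeddings may repeat the last position
    }
    return y0 == p0 + 1;                  // tokens must stay consecutive
}

int main() {
    assert( positions_ok(9, 12, /*is_mrope=*/true,  /*is_embd=*/false)); // forward jump allowed
    assert(!positions_ok(9,  9, /*is_mrope=*/true,  /*is_embd=*/false)); // repeat not allowed for M-RoPE
    assert( positions_ok(9, 10, /*is_mrope=*/false, /*is_embd=*/false)); // consecutive tokens
    return 0;
}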
13 changes: 12 additions & 1 deletion src/llama-batch.h
@@ -17,6 +17,16 @@ struct llama_ubatch {
return b_equal_seqs != 0;
}

// typical for M-RoPE cases:
// 0 - sequential position of the tokens/embeddings in the sequence
// 1 - y position in the image
// 2 - x position in the image
// 3 - other
bool is_pos_2d() const {
// TODO @ngxson : we may need to check for model arch when more models use >1 positions
return n_pos >= 3;
}

uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
// otherwise address sanitizer complains
// TODO: whole_seqs for embeddings?
@@ -25,6 +35,7 @@ struct llama_ubatch {
uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; // sequence sets in the ubatch
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
uint32_t n_pos; // number of position inputs for each token/embedding

// seq_id_unq: unique sequence ids in the ubatch
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -33,7 +44,7 @@ struct llama_ubatch {
// // size | idx | val
llama_token * token; // [n_tokens] | i | id, token
float * embd; // [n_embd, n_tokens] | i | embd
llama_pos * pos; // [n_tokens] | i | pos
llama_pos * pos; // [n_tokens*n_pos] | i | pos
int32_t * n_seq_id; // [n_tokens] | i | -
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
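To make the sectioned layout above concrete, here is a small standalone sketch (not part of the diff) of how a [n_tokens*n_pos] pos buffer can be indexed per token; the indexing matches the pos[i + n_tokens] / pos[i + n_tokens*2] accesses used in the llama-kv-cache.cpp hunks below:

#include <cstdint>
#include <cstdio>

using llama_pos = int32_t;

// Sketch only: each "section" is n_tokens entries long, laid out one after another.
struct pos_view {
    const llama_pos * pos;      // [n_tokens * n_pos]
    uint32_t          n_tokens;

    llama_pos temporal(uint32_t i) const { return pos[i];                } // section 0
    llama_pos y       (uint32_t i) const { return pos[i + n_tokens];     } // section 1
    llama_pos x       (uint32_t i) const { return pos[i + n_tokens*2];   } // section 2
};

int main() {
    // two tokens, n_pos = 4 (temporal, y, x, other), stored section by section
    const llama_pos pos[] = {
        0, 1,   // temporal
        0, 0,   // y
        0, 1,   // x
        0, 0,   // other
    };

    const pos_view v = { pos, /*n_tokens=*/2 };
    std::printf("token 1: t=%d y=%d x=%d\n", v.temporal(1), v.y(1), v.x(1)); // t=1 y=0 x=1
    return 0;
}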
26 changes: 26 additions & 0 deletions src/llama-kv-cache.cpp
@@ -900,6 +900,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &

cells.pos_set(idx, ubatch.pos[i]);

if (ubatch.is_pos_2d()) {
llama_kv_cell_ext ext {
/*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
/*.y =*/ ubatch.pos[i + ubatch.n_tokens],
};
cells.ext_set(idx, std::move(ext));
}

for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(idx, ubatch.seq_id[i][s]);
}
@@ -1247,6 +1255,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u

const llama_pos p1 = ubatch->pos[i];

// for M-RoPE
const bool is_2d = ubatch->is_pos_2d();
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;

const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1266,6 +1279,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
continue;
}

// M-RoPE causal mask
if (causal_attn && is_2d && p0 == p1) {
const auto & p0_ext = cells.ext_get(j);
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
continue;
}
}

// apply SWA if any
if (is_masked_swa(p0, p1)) {
continue;
@@ -1559,6 +1580,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
io.write(&pos, sizeof(pos));
io.write(&n_seq_id, sizeof(n_seq_id));

// TODO: we also need to save llama_kv_cell_ext once apply_ubatch() supports loading it
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350

for (const auto & seq_id : seq_ids) {
io.write(&seq_id, sizeof(seq_id));
}
@@ -1704,6 +1728,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
return false;
}

// TODO: we cannot yet restore llama_kv_cell_ext, as apply_ubatch() does not support loading it
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
apply_ubatch(sinfo, ubatch);
Comment on lines +1737 to 1739
Member

I don't understand this statement - the apply_ubatch() does handle ext:

if (ubatch.is_pos_2d()) {
llama_kv_cell_ext ext {
/*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
/*.y =*/ ubatch.pos[i + ubatch.n_tokens],
};
cells.ext_set(idx, std::move(ext));
}

Collaborator Author
@ngxson ngxson Oct 29, 2025

I mean the ext is constructed from pos, but ideally what I want is for apply_ubatch() to take the raw ext read from the save file.

The benefit is that when ext carries more info than just x,y, we won't need to update the save/load code again.

Another approach could be the other way around: on saving the state, we "serialize" ext back into a list of pos that can later be fed into the ubatch. But IMO this is a bit hacky.
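A rough, hypothetical sketch of the first idea (writing ext next to pos so a loader can restore it directly), modeled on the io.write(&pos, ...) pattern visible in state_write_meta() above; the helper name and file layout are made up for illustration and are not what the PR implements:

#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t;

struct llama_kv_cell_ext {
    llama_pos x = 0;
    llama_pos y = 0;
};

// Hypothetical: write pos and ext side by side for each cell; a loader could then
// hand the raw ext back to the cells without rebuilding it from the pos sections.
static void write_cell_meta(std::FILE * f,
                            const std::vector<llama_pos>         & pos,
                            const std::vector<llama_kv_cell_ext> & ext) {
    for (size_t i = 0; i < pos.size(); ++i) {
        std::fwrite(&pos[i], sizeof(pos[i]), 1, f);
        std::fwrite(&ext[i], sizeof(ext[i]), 1, f);
    }
}

int main() {
    std::vector<llama_pos>         pos = { 0, 1, 2 };
    std::vector<llama_kv_cell_ext> ext = { {0, 0}, {1, 0}, {0, 1} };

    if (std::FILE * f = std::fopen("cells.bin", "wb")) {
        write_cell_meta(f, pos, ext);
        std::fclose(f);
    }
    return 0;
}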


const auto head_cur = sinfo.head();
29 changes: 29 additions & 0 deletions src/llama-kv-cells.h
@@ -9,6 +9,17 @@
#include <set>
#include <map>

struct llama_kv_cell_ext {
// 2D spatial positions, typically used for M-RoPE
llama_pos x = 0;
llama_pos y = 0;

// return true if the current 2D spatial position is greater than other
bool is_2d_gt(llama_pos ox, llama_pos oy) const {
return (y > oy) || (y == oy && x > ox);
}
};

// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells {
@@ -43,6 +54,7 @@ class llama_kv_cells {

void resize(uint32_t n) {
pos.resize(n);
ext.resize(n);
shift.resize(n);
seq.resize(n);

@@ -108,6 +120,7 @@ class llama_kv_cells {
const auto idx = i + j;

res.pos[j] = pos[idx];
res.ext[j] = ext[idx];
res.seq[j] = seq[idx];

assert(shift[idx] == 0);
@@ -126,6 +139,7 @@ class llama_kv_cells {
const auto idx = idxs[j];

res.pos[j] = pos[idx];
res.ext[j] = ext[idx];
res.seq[j] = seq[idx];

assert(shift[idx] == 0);
@@ -340,6 +354,13 @@ class llama_kv_cells {
return pos[i];
}

const llama_kv_cell_ext & ext_get(uint32_t i) const {
assert(i < pos.size());
assert(pos[i] != -1);

return ext[i];
}

// note: call only if the cell is not empty
llama_pos get_shift(uint32_t i) const {
assert(i < pos.size());
@@ -368,6 +389,11 @@ class llama_kv_cells {
used.insert(i);
}

void ext_set(uint32_t i, llama_kv_cell_ext && p) {
assert(i < ext.size());
ext[i] = std::move(p);
}

// pos[i] = pos[i] + d
// sets "has_shift" to true
// note: call only if the cell is not empty
@@ -424,6 +450,9 @@ class llama_kv_cells {

std::vector<llama_pos> pos;

// stores extra info per cell
std::vector<llama_kv_cell_ext> ext;

// this array accumulates any applied shifts to the pos array since the last reset_shift() call
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
//
Expand Down
13 changes: 12 additions & 1 deletion tools/mtmd/mtmd.cpp
@@ -5,6 +5,15 @@

#include "llama.h"

// prevent windows.h from defining min/max macros that conflict with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif

#include <algorithm>
#include <cerrno>
#include <cstdio>
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {

llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
if (image_tokens->use_mrope_pos) {
return 1; // for M-RoPE, the whole image is 1 in temporal dimension
// for M-RoPE, temporal dimension = max(t,h,w)
// t is omitted as we don't support video input
return std::max(image_tokens->nx, image_tokens->ny);
}
return image_tokens->n_tokens();
}
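A small illustrative sketch, assuming the caller advances its running temporal position by mtmd_image_tokens_get_n_pos() after each chunk; with this change an image advances the position by max(nx, ny) instead of 1 (the grid size below is hypothetical):

#include <algorithm>
#include <cstdio>

int main() {
    int n_past = 10;        // temporal position before the image chunk
    int nx = 24, ny = 18;   // hypothetical image grid size in tokens

    n_past += std::max(nx, ny);   // previously this would have been n_past += 1

    std::printf("text after the image starts at temporal position %d\n", n_past); // 34
    return 0;
}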
4 changes: 2 additions & 2 deletions tools/mtmd/mtmd.h
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
// number of temporal positions (equal to max(t,h,w) for M-RoPE; equal to n_tokens otherwise)
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);

// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
// number of temporal positions (equal to max(t,h,w) for M-RoPE; equal to n_tokens otherwise)
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate

// tokenize an input text prompt and a list of bitmaps (images/audio)