67 changes: 41 additions & 26 deletions src/llama-batch.cpp
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
/*.n_seq_tokens =*/ (uint32_t) 1,
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
/*.n_pos =*/ n_pos_per_embd,
/*.token =*/ batch.token,
/*.embd =*/ batch.embd,
/*.pos =*/ batch.pos,
@@ -251,46 +252,58 @@ bool llama_batch_allocr::init(
// consistency checks
//

for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
if (n_pos_per_embd > 1) {
// M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}

const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

if (p0 >= 0 && p0 >= seq_pos_min(s)) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" for M-RoPE, it is required that the position satisfies: X < Y\n",
__func__, s, s, p0, s, seq_pos_min(s));

return false;
}
}
} else {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}

const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

if (p0 >= 0) {
bool ok = true;
if (p0 >= 0) {
bool ok = true;

if (batch.token) {
if (seq_pos_min(s) != p0 + 1) {
ok = false;
}
} else {
assert(batch.embd);

// for embeddings (typically used as vision input), we allow them to have repeating positions
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
ok = false;
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));

return false;
}
}

if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));

if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
}
}

if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
}
}

if (memory) {
@@ -389,6 +402,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ n_seqs,
/*.n_pos =*/ n_pos_per_embd,

/*.token =*/ udata->token.data(),
/*.embd =*/ nullptr,
@@ -710,6 +724,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
/*.n_seq_tokens =*/ n_tokens/n_seqs,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
/*.n_pos =*/ n_pos_per_embd,

/*.token =*/ batch.token ? udata->token.data() : nullptr,
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
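Note on the consistency checks above: the two branches enforce different rules on the incoming positions. Below is a minimal sketch of those rules in isolation; the helper name and signature are hypothetical and not part of the patch.

```cpp
#include <cstdint>

using llama_pos = int32_t;

// Hypothetical helper summarizing the checks in llama_batch_allocr::init():
//   p0       - last position stored in the KV cache for the sequence (-1 if empty)
//   y_min    - smallest position of the sequence in the incoming batch
//   is_mrope - true when n_pos_per_embd > 1 (M-RoPE, several position rows per token)
//   has_embd - true when the batch carries embeddings instead of token ids
static bool seq_pos_is_consistent(llama_pos p0, llama_pos y_min, bool is_mrope, bool has_embd) {
    if (p0 < 0) {
        return true; // nothing in the cache yet - any starting position is accepted
    }
    if (is_mrope) {
        // M-RoPE: positions may jump forward, but must strictly increase: X < Y
        return y_min > p0;
    }
    if (has_embd) {
        // vision embeddings may repeat the last position: Y == X or Y == X + 1
        return y_min == p0 || y_min == p0 + 1;
    }
    // plain tokens must remain consecutive: Y == X + 1
    return y_min == p0 + 1;
}
```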
13 changes: 12 additions & 1 deletion src/llama-batch.h
@@ -17,6 +17,16 @@ struct llama_ubatch {
return b_equal_seqs != 0;
}

// typical for M-RoPE cases:
// 0 - sequential position of the tokens/embeddings in the sequence
// 1 - y position in the image
// 2 - x position in the image
// 3 - other
bool is_pos_2d() const {
// TODO @ngxson : we may need to check for model arch when more models use >1 positions
return n_pos >= 3;
}

uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
// otherwise address sanitizer complains
// TODO: whole_seqs for embeddings?
@@ -25,6 +35,7 @@ struct llama_ubatch {
uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; // sequence sets in the ubatch
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
uint32_t n_pos; // number of position inputs for each token/embedding

// seq_id_unq: unique sequence ids in the ubatch
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -33,7 +44,7 @@ struct llama_ubatch {
// // size | idx | val
llama_token * token; // [n_tokens] | i | id, token
float * embd; // [n_embd, n_tokens] | i | embd
llama_pos * pos; // [n_tokens] | i | pos
llama_pos * pos; // [n_tokens*n_pos] | i | pos
int32_t * n_seq_id; // [n_tokens] | i | -
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
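For reference, this is how the `pos` buffer is assumed to be laid out once `n_pos > 1` (row-major, one row of `n_tokens` entries per position dimension); the wrapper struct and accessors below are illustrative only:

```cpp
#include <cstdint>

using llama_pos = int32_t;

// Assumed M-RoPE layout of llama_ubatch::pos (size n_tokens * n_pos):
//   pos[0*n_tokens + i] -> sequential (temporal) position of token i
//   pos[1*n_tokens + i] -> y position of token i in the image
//   pos[2*n_tokens + i] -> x position of token i in the image
//   pos[3*n_tokens + i] -> other (unused for image input)
struct ubatch_view {
    uint32_t    n_tokens;
    llama_pos * pos;

    llama_pos pos_t(uint32_t i) const { return pos[i]; }
    llama_pos pos_y(uint32_t i) const { return pos[i + n_tokens]; }
    llama_pos pos_x(uint32_t i) const { return pos[i + n_tokens*2]; }
};
```

This matches the indexing used in `llama_kv_cache::apply_ubatch` and `set_input_kq_mask` in the diff above.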
19 changes: 19 additions & 0 deletions src/llama-kv-cache.cpp
@@ -900,6 +900,13 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &

cells.pos_set(idx, ubatch.pos[i]);

if (ubatch.is_pos_2d()) {
llama_kv_cell_ext ext;
ext.x = ubatch.pos[i + ubatch.n_tokens*2];
ext.y = ubatch.pos[i + ubatch.n_tokens];
cells.ext_set(idx, std::move(ext));
}

for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(idx, ubatch.seq_id[i][s]);
}
@@ -1247,6 +1254,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u

const llama_pos p1 = ubatch->pos[i];

// for M-RoPE
llama_pos p1_x = ubatch->pos[i + ubatch->n_tokens*2];
llama_pos p1_y = ubatch->pos[i + ubatch->n_tokens];

const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1266,6 +1277,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
continue;
}

// M-RoPE causal mask
if (causal_attn && ubatch->is_pos_2d() && p0 == p1) {
const auto & p0_ext = cells.ext_get(j);
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
continue;
}
}

// apply SWA if any
if (is_masked_swa(p0, p1)) {
continue;
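The extra masking rule added to `set_input_kq_mask` only fires when two tokens share the same temporal position, i.e. belong to the same image. A condensed sketch of the predicate, assuming the `llama_kv_cell_ext` type from `llama-kv-cells.h` is available (the helper itself is hypothetical):

```cpp
#include "llama-kv-cells.h" // assumed to be on the include path

// Returns true if KV cell j must be masked out for the current query token:
// within one image (p0 == p1), a token may not attend to cells that come later
// in raster order (first by y, then by x).
static bool mrope_extra_mask(bool causal_attn,
                             llama_pos p0, const llama_kv_cell_ext & p0_ext, // KV cell
                             llama_pos p1, llama_pos p1_x, llama_pos p1_y) { // query token
    return causal_attn && p0 == p1 && p0_ext.is_2d_gt(p1_x, p1_y);
}
```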
29 changes: 29 additions & 0 deletions src/llama-kv-cells.h
@@ -9,6 +9,17 @@
#include <set>
#include <map>

struct llama_kv_cell_ext {
// 2D spatial positions, typically used for M-RoPE
llama_pos x = 0;
llama_pos y = 0;

// return true if the current 2D spatial position is greater than other
bool is_2d_gt(llama_pos ox, llama_pos oy) const {
return (y > oy) || (y == oy && x > ox);
}
};

// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells {
@@ -43,6 +54,7 @@ class llama_kv_cells {

void resize(uint32_t n) {
pos.resize(n);
ext.resize(n);
shift.resize(n);
seq.resize(n);

@@ -108,6 +120,7 @@ class llama_kv_cells {
const auto idx = i + j;

res.pos[j] = pos[idx];
res.ext[j] = ext[idx];
res.seq[j] = seq[idx];

assert(shift[idx] == 0);
@@ -126,6 +139,7 @@ class llama_kv_cells {
const auto idx = idxs[j];

res.pos[j] = pos[idx];
res.ext[j] = ext[idx];
res.seq[j] = seq[idx];

assert(shift[idx] == 0);
@@ -340,6 +354,13 @@ class llama_kv_cells {
return pos[i];
}

const llama_kv_cell_ext & ext_get(uint32_t i) const {
assert(i < pos.size());
assert(pos[i] != -1);

return ext[i];
}

// note: call only if the cell is not empty
llama_pos get_shift(uint32_t i) const {
assert(i < pos.size());
@@ -368,6 +389,11 @@ class llama_kv_cells {
used.insert(i);
}

void ext_set(uint32_t i, llama_kv_cell_ext && p) {
assert(i < ext.size());
ext[i] = std::move(p);
}

// pos[i] = pos[i] + d
// sets "has_shift" to true
// note: call only if the cell is not empty
@@ -424,6 +450,9 @@ class llama_kv_cells {

std::vector<llama_pos> pos;

// stores extra info per cell
std::vector<llama_kv_cell_ext> ext;

// this array accumulates any applied shifts to the pos array since the last reset_shift() call
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
//
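A quick check of the ordering `is_2d_gt()` implements (a later row wins, then a later column); the values are made up and `llama-kv-cells.h` is assumed to be includable on its own:

```cpp
#include <cassert>

#include "llama-kv-cells.h" // assumed include path

int main() {
    llama_kv_cell_ext a;
    a.x = 3; // column of the stored cell inside the image
    a.y = 7; // row of the stored cell inside the image

    assert( a.is_2d_gt(2, 7)); // same row, stored cell is further right -> greater
    assert( a.is_2d_gt(9, 6)); // stored cell is on a later row          -> greater
    assert(!a.is_2d_gt(3, 7)); // identical position                     -> not greater
    assert(!a.is_2d_gt(0, 8)); // stored cell is on an earlier row       -> not greater

    return 0;
}
```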
13 changes: 12 additions & 1 deletion tools/mtmd/mtmd.cpp
@@ -5,6 +5,15 @@

#include "llama.h"

// prevent windows.h from defining min/max macros that conflict with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif

#include <algorithm>
#include <cerrno>
#include <cstdio>
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {

llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
if (image_tokens->use_mrope_pos) {
return 1; // for M-RoPE, the whole image is 1 in temporal dimension
// for M-RoPE, temporal dimension = max(t,h,w)
// t is omitted as we don't support video input
return std::max(image_tokens->nx, image_tokens->ny);
}
return image_tokens->n_tokens();
}
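The position accounting change in `mtmd_image_tokens_get_n_pos` can be summarized as follows; the struct below is an illustrative stand-in, not the real `mtmd_image_tokens` type:

```cpp
#include <algorithm>
#include <cstddef>

// Illustrative stand-in for the image-token bookkeeping in mtmd.cpp:
struct image_tokens_sketch {
    size_t nx, ny;        // token grid of the image (width x height)
    bool   use_mrope_pos; // true for M-RoPE models

    size_t n_tokens() const { return nx * ny; }

    // how far the image advances the temporal position counter
    size_t n_pos() const {
        if (use_mrope_pos) {
            // M-RoPE: temporal span = max(t, h, w); t is omitted (no video input)
            return std::max(nx, ny);
        }
        // otherwise every image token consumes one position
        return n_tokens();
    }
};
```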
4 changes: 2 additions & 2 deletions tools/mtmd/mtmd.h
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
// number of temporal positions (equal to max(t,h,w) for M-RoPE; equal to n_tokens otherwise)
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);

// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
// number of temporal positions (equal to max(t,h,w) for M-RoPE; equal to n_tokens otherwise)
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate

// tokenize an input text prompt and a list of bitmaps (images/audio)