1 change: 1 addition & 0 deletions include/llama.h
@@ -218,6 +218,7 @@ extern "C" {
// - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// (for M-RoPE, the first `n_tokens` positions are linearly increasing; the next `n_pos_per_embd * n_tokens` positions hold the M-RoPE sections)
// (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
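
The documented layout can be illustrated with a short sketch (not part of the patch; `make_mrope_pos` and the grid values are assumptions for illustration): for an embedding batch of `nx * ny` image patches with `n_pos_per_embd = 4`, the caller fills `n_tokens` linearly increasing positions first, followed by the temporal, height, width, and unused sections, each of length `n_tokens`.

```cpp
// Sketch only: fill a pos buffer in the layout described above.
// make_mrope_pos is a hypothetical helper, not part of llama.h.
#include <cstdint>
#include <vector>

typedef int32_t llama_pos;

static std::vector<llama_pos> make_mrope_pos(int nx, int ny, llama_pos pos_0) {
    const int n_tokens       = nx * ny; // one embedding per image patch
    const int n_pos_per_embd = 4;       // M-RoPE: t, h, w + one unused dim
    std::vector<llama_pos> pos(n_tokens * (n_pos_per_embd + 1));
    for (int y = 0; y < ny; y++) {
        for (int x = 0; x < nx; x++) {
            const int i = y * nx + x;
            pos[i]                = pos_0 + i; // linearly increasing prefix
            pos[i + n_tokens]     = pos_0;     // temporal section
            pos[i + n_tokens * 2] = pos_0 + y; // height section
            pos[i + n_tokens * 3] = pos_0 + x; // width section
            pos[i + n_tokens * 4] = 0;         // unused section
        }
    }
    return pos;
}
```
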
25 changes: 6 additions & 19 deletions src/llama-batch.cpp
@@ -259,23 +259,7 @@ bool llama_batch_allocr::init(
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;

if (p0 >= 0) {
bool ok = true;

if (batch.token) {
if (seq_pos_min(s) != p0 + 1) {
ok = false;
}
} else {
assert(batch.embd);

// for embeddings (typically used as vision input), we allow them to have repeating positions
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
ok = false;
}
}

if (!ok) {
if (seq_pos_min(s) != p0 + 1) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
@@ -655,7 +639,10 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u

auto udata = std::make_shared<llama_ubatch::data_t>();

const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
const int32_t n_pos_per_embd_inp = n_pos_per_embd > 1
? (n_pos_per_embd + 1) // include the extra linearly increasing positions for M-RoPE
: 1; // standard RoPE
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd_inp : 1;

const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
@@ -681,7 +668,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
}

for (int j = 0; j < n_pos_cur; ++j) {
udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
udata->pos[j * n_tokens + i] = batch.pos[j * batch.n_tokens + idxs[i]];
}

udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
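
A minimal sketch of the sizing rule introduced above (`pos_per_token` is a hypothetical helper, not code from llama-batch.cpp): token batches contribute one position per token, while embedding batches under M-RoPE contribute `n_pos_per_embd + 1` to account for the extra linear prefix.

```cpp
#include <cstdint>

// Sketch only: how many llama_pos entries each token of a batch contributes.
static int32_t pos_per_token(bool is_embd_batch, int32_t n_pos_per_embd) {
    if (!is_embd_batch) {
        return 1; // token batch: a single position per token
    }
    return n_pos_per_embd > 1
        ? n_pos_per_embd + 1 // M-RoPE: position sections plus the linear prefix
        : 1;                 // embeddings with standard RoPE
}
// The total size of the pos buffer is then n_tokens * pos_per_token(...).
```
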
5 changes: 4 additions & 1 deletion src/llama-graph.cpp
@@ -54,7 +54,10 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
}
ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
} else {
ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
const bool has_mrope = ubatch->embd && n_pos_per_embd > 1;
ggml_backend_tensor_set(pos,
ubatch->pos + (has_mrope ? n_tokens : 0), // skip the first n_tokens positions for M-RoPE
0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
}
}
}
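
The pointer arithmetic above can be sketched as follows (assuming the contiguous `[linear prefix | M-RoPE sections]` layout documented in llama.h; `pos_for_rope` is a hypothetical stand-in): the graph input skips the linear prefix and uploads only the `n_pos_per_embd * n_tokens` M-RoPE positions.

```cpp
#include <cstdint>

typedef int32_t llama_pos;

// Sketch only: select the part of the ubatch pos buffer that feeds the RoPE graph.
static const llama_pos * pos_for_rope(const llama_pos * pos, bool is_embd,
                                      int64_t n_tokens, int64_t n_pos_per_embd) {
    const bool has_mrope = is_embd && n_pos_per_embd > 1;
    // with M-RoPE, the first n_tokens entries are the linear prefix and are
    // not consumed by the RoPE graph input
    return pos + (has_mrope ? n_tokens : 0);
}
```
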
3 changes: 2 additions & 1 deletion tools/mtmd/mtmd-cli.cpp
@@ -191,7 +191,8 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {

// eval the token
common_batch_clear(ctx.batch);
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
int max_pos = llama_memory_seq_pos_max(llama_get_memory(ctx.lctx), 0);
common_batch_add(ctx.batch, token_id, max_pos+1, {0}, true);
if (llama_decode(ctx.lctx, ctx.batch)) {
LOG_ERR("failed to decode token\n");
return 1;
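
The same pattern can be captured in a small sketch (`append_next_token` is a hypothetical helper; it assumes the llama.h and common.h APIs already used in this repo): instead of incrementing a locally tracked `n_past`, the next position is derived from the last position stored in the context's memory module.

```cpp
#include "llama.h"
#include "common.h"

// Sketch only: append one token to `batch` for sequence 0, placed right after
// the last position currently stored in the KV cache.
static void append_next_token(llama_context * lctx, llama_batch & batch, llama_token token_id) {
    const llama_pos max_pos = llama_memory_seq_pos_max(llama_get_memory(lctx), 0);
    common_batch_add(batch, token_id, max_pos + 1, { 0 }, /*logits=*/true);
}
```
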
101 changes: 73 additions & 28 deletions tools/mtmd/mtmd-helper.cpp
@@ -58,15 +58,16 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
struct decode_embd_batch {
int n_pos_per_embd;
int n_mmproj_embd;
std::vector<llama_pos> pos;
std::vector<llama_pos> pos; // for M-RoPE, this will have (1+n_pos_per_embd)*n_tokens elements
// the extra n_tokens are for linearly increasing positions
std::vector<llama_pos> pos_view; // used by mrope
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
pos .resize(n_tokens * n_pos_per_embd);
pos .resize(n_tokens * (n_pos_per_embd + 1));
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
@@ -100,13 +101,14 @@ struct decode_embd_batch {
for (int y = 0; y < ny; y++) {
for (int x = 0; x < nx; x++) {
int i = y * nx + x;
pos[i ] = pos_0;
pos[i + batch.n_tokens ] = pos_0 + y;
pos[i + batch.n_tokens * 2] = pos_0 + x;
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
pos[i + batch.n_tokens ] = pos_0;
pos[i + batch.n_tokens * 2] = pos_0 + y;
pos[i + batch.n_tokens * 3] = pos_0 + x;
pos[i + batch.n_tokens * 4] = 0; // last pos dim is unused
}
}
for (int i = 0; i < batch.n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
@@ -118,12 +120,13 @@ struct decode_embd_batch {
GGML_ASSERT(n_pos_per_embd == 4);
seq_id_0[0] = seq_id;
for (int i = 0; i < batch.n_tokens; i++) {
pos[i ] = pos_0 + i;
pos[i + batch.n_tokens ] = pos_0 + i;
pos[i + batch.n_tokens * 2] = pos_0 + i;
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
pos[i + batch.n_tokens * 3] = pos_0 + i;
pos[i + batch.n_tokens * 4] = 0; // last pos dim is unused
}
for (int i = 0; i < batch.n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
@@ -133,12 +136,12 @@ struct decode_embd_batch {
llama_batch get_view(int offset, int n_tokens) {
llama_pos * pos_ptr;
pos_view.clear();
pos_view.reserve(n_tokens * n_pos_per_embd);
pos_view.reserve(n_tokens * (n_pos_per_embd + 1));
if (n_pos_per_embd > 1) {
// mrope
// for example, with layout of src: 1234...1234...1234...1234...
// offset 2 will give us dst: 34...34...34...34...
for (int i = 0; i < n_pos_per_embd; i++) {
for (int i = 0; i <= n_pos_per_embd; i++) {
// assume n_tokens is less than or equal to batch.n_tokens
// batch.n_tokens is the number of **total** tokens
// n_tokens is the number of viewed tokens
@@ -164,6 +167,59 @@ struct decode_embd_batch {
}
};
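
The section-wise slicing performed by `get_view` for M-RoPE can be sketched like this (`slice_mrope_pos` is a hypothetical stand-in, not the code in this file): each of the `n_pos_per_embd + 1` sections is stored contiguously with stride `batch.n_tokens`, so a view of `n_view` tokens at `offset` copies the same window out of every section.

```cpp
#include <cstdint>
#include <vector>

typedef int32_t llama_pos;

// Sketch only: with src layout 1234...1234...1234..., offset 2 and n_view 2
// produce dst 34...34...34... for every position section.
static std::vector<llama_pos> slice_mrope_pos(const std::vector<llama_pos> & src,
                                              int n_tokens_total, int n_sections,
                                              int offset, int n_view) {
    std::vector<llama_pos> dst;
    dst.reserve((size_t) n_sections * n_view);
    for (int s = 0; s < n_sections; s++) {
        for (int i = 0; i < n_view; i++) {
            dst.push_back(src[(size_t) s * n_tokens_total + offset + i]);
        }
    }
    return dst;
}
```
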

// helper struct to make working with text batch easier
struct decode_text_batch {
std::vector<llama_token> tokens;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_seq_id seq_id;
llama_batch batch;
decode_text_batch(int32_t n_tokens, llama_seq_id seq_id) : seq_id(seq_id) {
tokens .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_ids[n_tokens] = nullptr;
for (int32_t i = 0; i < n_tokens; i++) {
n_seq_id[i] = 1;
seq_ids [i] = &this->seq_id;
}
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens.data(),
/*embd =*/ nullptr,
/*pos =*/ nullptr, // position is tracked automatically
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
}

void clear() {
batch.n_tokens = 0;
}

bool is_full() const {
return batch.n_tokens >= (int32_t) tokens.size();
}

void add_token(llama_token tok, bool output) {
GGML_ASSERT(!is_full());
int32_t j = batch.n_tokens;
batch.token [j] = tok;
batch.logits[j] = output;
batch.n_tokens++;
}

void set_logits_last() {
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
}
}
};

// Helper function for decoding an image whose embeddings have already been calculated
int32_t mtmd_helper_decode_image_chunk(
mtmd_context * ctx,
@@ -252,7 +308,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
bool logits_last,
llama_pos * new_n_past) {
int32_t ret;
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
decode_text_batch text_batch(n_batch, seq_id);
auto chunk_type = mtmd_input_chunk_get_type(chunk);

if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
@@ -261,28 +317,20 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
// LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
size_t i = 0;
while (i < n_tokens) { // split into batches
text_batch.n_tokens = 0; // clear the batch
for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
int32_t j = text_batch.n_tokens;
text_batch.token [j] = tokens[i];
text_batch.pos [j] = n_past++;
text_batch.n_seq_id[j] = 1;
text_batch.seq_id [j][0] = seq_id;
text_batch.logits [j] = false;

text_batch.n_tokens++;
text_batch.clear();
for (; i < n_tokens && !text_batch.is_full(); i++) {
text_batch.add_token(tokens[i], false);
}
bool is_last_token = (i == n_tokens);
if (logits_last && is_last_token) {
text_batch.logits[text_batch.n_tokens - 1] = true;
text_batch.set_logits_last();
}
ret = llama_decode(lctx, text_batch);
ret = llama_decode(lctx, text_batch.batch);
if (ret != 0) {
LOG_ERR("failed to decode text\n");
llama_batch_free(text_batch);
return ret;
}
*new_n_past += text_batch.n_tokens;
*new_n_past += text_batch.batch.n_tokens;
}

} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
@@ -294,7 +342,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
ret = mtmd_encode_chunk(ctx, chunk);
if (ret != 0) {
LOG_ERR("failed to encode %s slice\n", name);
llama_batch_free(text_batch);
return ret;
}

@@ -304,14 +351,12 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
if (ret != 0) {
LOG_ERR("failed to decode %s\n", name);
llama_batch_free(text_batch);
return ret;
}
} else {
GGML_ABORT("chunk type not supported");
}

llama_batch_free(text_batch);
return 0;
}

13 changes: 12 additions & 1 deletion tools/mtmd/mtmd.cpp
@@ -5,6 +5,15 @@

#include "llama.h"

// prevent <windows.h> from defining the min/max macros that conflict with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif

#include <algorithm>
#include <cerrno>
#include <cstdio>
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {

llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
if (image_tokens->use_mrope_pos) {
return 1; // for M-RoPE, the whole image is 1 in temporal dimension
// for M-RoPE, temporal dimension = max(t,h,w)
// t is omitted as we don't support video input
return std::max(image_tokens->nx, image_tokens->ny);
}
return image_tokens->n_tokens();
}
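
A worked example of the new accounting (the grid values are assumed for illustration): an image split into a 16x24 patch grid feeds 384 embeddings into the model, but with M-RoPE it advances the sequence position by only max(16, 24) = 24.

```cpp
#include <algorithm>
#include <cassert>

int main() {
    const int nx = 16, ny = 24;         // assumed patch grid of one image
    const int n_embd_tokens = nx * ny;  // 384 embeddings are decoded
    const int n_pos = std::max(nx, ny); // positions consumed under M-RoPE
    assert(n_embd_tokens == 384 && n_pos == 24);
    return 0;
}
```
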
2 changes: 1 addition & 1 deletion tools/mtmd/mtmd.h
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
// number of temporal positions (max(t,h,w) for M-RoPE; n_tokens otherwise)
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate

// tokenize an input text prompt and a list of bitmaps (images/audio)