Commit edc53a9

improve it
1 parent: ef1af68

7 files changed (+89, -35 lines)

include/llama.h

Lines changed: 2 additions & 1 deletion
@@ -218,6 +218,7 @@ extern "C" {
     //  - token  : the token ids of the input (used when embd is NULL)
     //  - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     //  - pos    : the positions of the respective token in the sequence
+    //             (for M-RoPE, first `n_tokens` are linearly increasing, followed by `n_pos_per_embd * n_tokens` positions for RoPE)
     //             (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
     //  - seq_id : the sequence to which the respective token belongs
     //             (if set to NULL, the sequence ID will be assumed to be 0)
@@ -232,7 +233,7 @@ extern "C" {

         llama_token  *  token;
         float        *  embd;
-        llama_pos    *  pos; // first `n_tokens` elements are always linearly increasing position for traditional llm
+        llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits; // TODO: rename this to "output"
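
A minimal sketch of the new `pos` contract (hypothetical, not part of this commit): filling the array for an M-RoPE embedding batch. The helper name and the choice of n_pos_per_embd = 4 are assumptions for illustration, and the M-RoPE block itself is taken as already computed elsewhere.

    #include "llama.h"

    // hypothetical illustration of the layout described in the comment above
    static void fill_pos_for_mrope_embd_batch(
                  llama_pos * pos,       // destination: (1 + 4) * n_tokens entries
            const llama_pos * mrope_pos, // precomputed M-RoPE positions: 4 * n_tokens entries
            int32_t n_tokens, llama_pos pos_0) {
        const int32_t n_pos_per_embd = 4;
        // first n_tokens entries: plain linearly increasing positions,
        // still used for causal attention masking
        for (int32_t i = 0; i < n_tokens; i++) {
            pos[i] = pos_0 + i;
        }
        // remaining n_pos_per_embd * n_tokens entries: the model-specific M-RoPE positions
        for (int64_t j = 0; j < (int64_t) n_pos_per_embd * n_tokens; j++) {
            pos[n_tokens + j] = mrope_pos[j];
        }
    }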

src/llama-batch.cpp

Lines changed: 4 additions & 1 deletion
@@ -639,7 +639,10 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u

     auto udata = std::make_shared<llama_ubatch::data_t>();

-    const int32_t n_pos_cur = batch.embd ? (n_pos_per_embd + 1) : 1;
+    const int32_t n_pos_per_embd_inp = n_pos_per_embd > 1
+        ? (n_pos_per_embd + 1) // include the extra linearly increasing positions for M-RoPE
+        : 1;                   // standard RoPE
+    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd_inp : 1;

     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
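
The resulting buffer size can be sanity-checked with a small standalone sketch (values assumed, not from the commit): a text batch keeps one position per token, while an M-RoPE embedding batch with n_pos_per_embd = 4 now allocates five.

    #include <cassert>
    #include <cstdint>

    // illustrative check of the sizing logic above
    int main() {
        const int32_t n_tokens       = 8;
        const int32_t n_pos_per_embd = 4;    // M-RoPE model (e.g. Qwen2-VL)
        const bool    has_embd       = true; // embedding (image) batch

        const int32_t n_pos_per_embd_inp = n_pos_per_embd > 1
            ? n_pos_per_embd + 1  // extra linearly increasing positions for M-RoPE
            : 1;                  // standard RoPE
        const int32_t n_pos_cur = has_embd ? n_pos_per_embd_inp : 1;

        const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
        assert(n_pos_all == 40); // 8 linear positions + 4*8 M-RoPE positions
        return 0;
    }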

src/llama-graph.cpp

Lines changed: 4 additions & 7 deletions
@@ -54,13 +54,10 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
             }
             ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
         } else {
-            llama_pos * pos_ptr = ubatch->pos;
-            // Normally, ubatch->pos stores linearly increasing position
-            // However, some multi-modal models requires special position embedding (e.g. M-Rope in qwen2vl and qwen2.5vl)
-            // But linearly increasing position is still needed for proper causal attention masking
-            // So we store both of them: the first n_tokens elements are not changed, while model-specific positions are appended after that.
-            if (ubatch->embd && n_pos_per_embd > 1) pos_ptr += n_tokens; // use mrope positions
-            ggml_backend_tensor_set(pos, pos_ptr, 0, n_tokens * n_pos_per_embd * ggml_element_size(pos));
+            const bool has_mrope = ubatch->embd && n_pos_per_embd > 1;
+            ggml_backend_tensor_set(pos,
+                    ubatch->pos + (has_mrope ? n_tokens : 0), // skip the first n_tokens positions for M-RoPE
+                    0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
         }
     }
 }
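
On the consumer side, only the n_pos_per_embd * n_tokens M-RoPE block is uploaded to the graph's `pos` tensor and the linear prefix is skipped; the standalone sketch below (assumed values, not from the commit) spells out the offset and byte count involved.

    #include <cstdint>
    #include <cstdio>

    // illustrative arithmetic for the tensor copy above
    int main() {
        const int64_t n_tokens       = 8;
        const int64_t n_pos_per_embd = 4;               // M-RoPE model
        const bool    has_mrope      = true;            // embedding batch on an M-RoPE model
        const size_t  elt_size       = sizeof(int32_t); // llama_pos is a 32-bit integer

        const int64_t src_offset = has_mrope ? n_tokens : 0;                    // skip the linear prefix
        const size_t  n_bytes    = (size_t) (n_tokens*n_pos_per_embd*elt_size); // 8*4*4 = 128 bytes

        printf("copy %zu bytes starting at ubatch->pos + %lld\n", n_bytes, (long long) src_offset);
        return 0;
    }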

tools/mtmd/mtmd-cli.cpp

Lines changed: 2 additions & 1 deletion
@@ -191,7 +191,8 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {

         // eval the token
         common_batch_clear(ctx.batch);
-        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
+        int max_pos = llama_memory_seq_pos_max(llama_get_memory(ctx.lctx), 0);
+        common_batch_add(ctx.batch, token_id, max_pos+1, {0}, true);
         if (llama_decode(ctx.lctx, ctx.batch)) {
             LOG_ERR("failed to decode token\n");
             return 1;
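
The same pattern can be isolated into a small hypothetical fragment (not from the commit): instead of a manually tracked ctx.n_past, the next position is derived from the memory module, which stays correct after M-RoPE image chunks whose positions do not advance one-per-token.

    #include "common.h"
    #include "llama.h"

    // decode one sampled token on sequence 0 without tracking n_past by hand
    // (assumes an initialized llama_context and llama_batch from the same llama.cpp tree)
    static int decode_one_token(llama_context * lctx, llama_batch & batch, llama_token token_id) {
        common_batch_clear(batch);
        // highest position already stored for seq 0 in the memory (KV cache);
        // -1 when the sequence is empty, so max_pos + 1 is always the next free position
        const llama_pos max_pos = llama_memory_seq_pos_max(llama_get_memory(lctx), 0);
        common_batch_add(batch, token_id, max_pos + 1, {0}, /*logits =*/ true);
        return llama_decode(lctx, batch);
    }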

tools/mtmd/mtmd-helper.cpp

Lines changed: 62 additions & 24 deletions
@@ -55,15 +55,11 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {

 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
-// notes2: Normally, batch's `pos` stores linearly increasing position
-// However, some multi-modal models requires special position embedding (e.g. M-Rope in qwen2vl and qwen2.5vl)
-// But linearly increasing position is still needed for proper causal attention masking
-// So we store both of them: the first n_tokens elements are not changed, while model-specific positions are appended after that.
-// So `pos` has `n_tokens * (n_pos_per_embd + 1)` elements
 struct decode_embd_batch {
     int n_pos_per_embd;
     int n_mmproj_embd;
-    std::vector<llama_pos>      pos;
+    std::vector<llama_pos>      pos; // for M-RoPE, this will have (1+n_pos_per_embd)*n_tokens elements
+                                     // the extra n_tokens are for linearly increasing positions
     std::vector<llama_pos>      pos_view; // used by mrope
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
@@ -171,6 +167,59 @@ struct decode_embd_batch {
     }
 };

+// helper struct to make working with embd batch easier
+struct decode_text_batch {
+    std::vector<llama_token>    tokens;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_seq_id seq_id;
+    llama_batch  batch;
+    decode_text_batch(int32_t n_tokens, llama_seq_id seq_id) : seq_id(seq_id) {
+        tokens  .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_ids[n_tokens] = nullptr;
+        for (int32_t i = 0; i < n_tokens; i++) {
+            n_seq_id[i] = 1;
+            seq_ids [i] = &this->seq_id;
+        }
+        batch = {
+            /*n_tokens =*/ n_tokens,
+            /*tokens   =*/ tokens.data(),
+            /*embd     =*/ nullptr,
+            /*pos      =*/ nullptr, // position is tracked automatically
+            /*n_seq_id =*/ n_seq_id.data(),
+            /*seq_id   =*/ seq_ids.data(),
+            /*logits   =*/ logits.data(),
+        };
+    }
+
+    void clear() {
+        batch.n_tokens = 0;
+    }
+
+    bool is_full() const {
+        return batch.n_tokens >= (int32_t) tokens.size();
+    }
+
+    void add_token(llama_token tok, bool output) {
+        GGML_ASSERT(!is_full());
+        int32_t j = batch.n_tokens;
+        batch.token [j] = tok;
+        batch.logits[j] = output;
+        batch.n_tokens++;
+    }
+
+    void set_logits_last() {
+        if (batch.n_tokens > 0) {
+            batch.logits[batch.n_tokens - 1] = true;
+        }
+    }
+};
+
 // Helper function for decoding an image whose embeddings have already been calculated
 int32_t mtmd_helper_decode_image_chunk(
         mtmd_context * ctx,
@@ -259,7 +308,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         bool logits_last,
         llama_pos * new_n_past) {
     int32_t ret;
-    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
+    decode_text_batch text_batch(n_batch, seq_id);
     auto chunk_type = mtmd_input_chunk_get_type(chunk);

     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
@@ -268,28 +317,20 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         // LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens);
         size_t i = 0;
         while (i < n_tokens) { // split into batches
-            text_batch.n_tokens = 0; // clear the batch
-            for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
-                int32_t j = text_batch.n_tokens;
-                text_batch.token   [j]    = tokens[i];
-                text_batch.pos     [j]    = n_past++;
-                text_batch.n_seq_id[j]    = 1;
-                text_batch.seq_id  [j][0] = seq_id;
-                text_batch.logits  [j]    = false;
-
-                text_batch.n_tokens++;
+            text_batch.clear();
+            for (; i < n_tokens && !text_batch.is_full(); i++) {
+                text_batch.add_token(tokens[i], false);
             }
             bool is_last_token = (i == n_tokens);
             if (logits_last && is_last_token) {
-                text_batch.logits[text_batch.n_tokens - 1] = true;
+                text_batch.set_logits_last();
             }
-            ret = llama_decode(lctx, text_batch);
+            ret = llama_decode(lctx, text_batch.batch);
             if (ret != 0) {
                 LOG_ERR("failed to decode text\n");
-                llama_batch_free(text_batch);
                 return ret;
             }
-            *new_n_past += text_batch.n_tokens;
+            *new_n_past += text_batch.batch.n_tokens;
         }

     } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
@@ -301,7 +342,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         ret = mtmd_encode_chunk(ctx, chunk);
         if (ret != 0) {
             LOG_ERR("failed to encode %s slice\n", name);
-            llama_batch_free(text_batch);
             return ret;
         }

@@ -311,14 +351,12 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
         if (ret != 0) {
             LOG_ERR("failed to decode %s\n", name);
-            llama_batch_free(text_batch);
             return ret;
         }
     } else {
         GGML_ABORT("chunk type not supported");
     }

-    llama_batch_free(text_batch);
     return 0;
 }
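
A hypothetical usage sketch of decode_text_batch, mirroring the loop above: tokens are split into batches of at most n_batch, with logits requested only for the final token. Because the storage is owned by the struct's std::vectors, the early-return path no longer needs llama_batch_free().

    #include "llama.h"
    #include <vector>

    // assumes decode_text_batch as defined above and an initialized llama_context
    static int32_t decode_tokens(llama_context * lctx,
                                 const std::vector<llama_token> & tokens,
                                 int32_t n_batch, llama_seq_id seq_id) {
        decode_text_batch text_batch(n_batch, seq_id);
        size_t i = 0;
        while (i < tokens.size()) {           // split into batches of at most n_batch tokens
            text_batch.clear();
            for (; i < tokens.size() && !text_batch.is_full(); i++) {
                text_batch.add_token(tokens[i], false);
            }
            if (i == tokens.size()) {
                text_batch.set_logits_last(); // only the last token of the last batch outputs logits
            }
            int32_t ret = llama_decode(lctx, text_batch.batch);
            if (ret != 0) {
                return ret;                   // vectors clean up on scope exit
            }
        }
        return 0;
    }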

tools/mtmd/mtmd.cpp

Lines changed: 14 additions & 0 deletions
@@ -5,6 +5,15 @@

 #include "llama.h"

+// fix problem with std::min and std::max
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+# define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
 #include <algorithm>
 #include <cerrno>
 #include <cstdio>
@@ -1030,6 +1039,11 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
 }

 llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
+    if (image_tokens->use_mrope_pos) {
+        // for M-RoPE, temporal dimension = max(t,h,w)
+        // t is omitted as we don't support video input
+        return std::max(image_tokens->nx, image_tokens->ny);
+    }
     return image_tokens->n_tokens();
 }
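
The practical effect of the new rule can be seen with a worked example (grid size assumed, not from the commit): an image encoded as a 16x24 grid of patch embeddings contributes 384 tokens to the batch, but under M-RoPE it advances the sequence position by only max(16, 24) = 24.

    #include <algorithm>
    #include <cassert>
    #include <cstddef>

    // standalone illustration of the position accounting above
    int main() {
        const size_t nx = 16, ny = 24;   // image patch grid (tokens per row / column)
        const size_t n_tokens = nx * ny; // 384 embeddings go into the batch

        const bool use_mrope_pos = true; // e.g. Qwen2-VL / Qwen2.5-VL
        const size_t n_pos = use_mrope_pos ? std::max(nx, ny) // temporal extent only
                                           : n_tokens;        // classic RoPE: one position per token

        assert(n_tokens == 384);
        assert(n_pos    == 24);
        return 0;
    }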

tools/mtmd/mtmd.h

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
 MTMD_API size_t       mtmd_image_tokens_get_nx      (const mtmd_image_tokens * image_tokens);
 MTMD_API size_t       mtmd_image_tokens_get_ny      (const mtmd_image_tokens * image_tokens);
 MTMD_API const char * mtmd_image_tokens_get_id      (const mtmd_image_tokens * image_tokens); // TODO: deprecate
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
 MTMD_API llama_pos    mtmd_image_tokens_get_n_pos   (const mtmd_image_tokens * image_tokens); // TODO: deprecate

 // tokenize an input text prompt and a list of bitmaps (images/audio)
