
Commit b303584

working version
1 parent 513e9c9 commit b303584

4 files changed, +75 -46 lines changed

examples/llava/mtmd-cli.cpp

Lines changed: 2 additions & 34 deletions
@@ -136,39 +136,6 @@ struct mtmd_cli_context {
     }
 };
 
-struct decode_embd_batch {
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ embd,
-            /*pos            =*/ pos.data(),
-            /*n_seq_id       =*/ n_seq_id.data(),
-            /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
@@ -243,7 +210,7 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
         return 1;
     }
 
-    ctx.n_past += mtmd_helper_get_n_tokens(chunks);
+    ctx.n_past += mtmd_helper_get_n_pos(chunks);
 
     return 0;
 }
@@ -371,6 +338,7 @@ int main(int argc, char ** argv) {
         }
     }
     if (g_is_interrupted) LOG("\nInterrupted by user\n");
+    LOG("\n\n");
     llama_perf_context_print(ctx.lctx);
     return g_is_interrupted ? 130 : 0;
 }
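Note on the eval_message change above: for M-RoPE models (Qwen2-VL / Qwen2.5-VL) an image occupies one KV-cache cell per token but only a single temporal position, so advancing n_past by the token count would overstate the position of the text that follows. A minimal sketch of the two counts, using the helpers from this commit but with hypothetical chunk sizes (the numbers are illustrative only):

// sketch: chunks = [text: 10 tokens][image: 4x6 = 24 tokens][text: 3 tokens]
size_t    n_kv  = mtmd_helper_get_n_tokens(chunks); // 10 + 24 + 3 = 37 KV cells, always
llama_pos n_pos = mtmd_helper_get_n_pos(chunks);    // 10 +  1 + 3 = 14 with M-RoPE, 37 otherwise
ctx.n_past += n_pos; // what the updated eval_message does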

examples/llava/mtmd.cpp

Lines changed: 66 additions & 9 deletions
@@ -128,6 +128,7 @@ struct mtmd_image_tokens_data {
 struct mtmd_image_tokens {
     uint32_t nx; // number of tokens in x direction
     uint32_t ny; // number of tokens in y direction
+    bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
     uint32_t n_tokens() const { return nx * ny; }
     clip_image_f32_batch batch_f32; // preprocessed image patches
     std::string id; // optional user-defined ID, useful for KV cache tracking
@@ -342,6 +343,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             // for Qwen2VL, we need this information for M-RoPE decoding positions
             image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_clip, batch_f32.entries[0].get());
             image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_clip, batch_f32.entries[0].get());
+            image_tokens->use_mrope_pos = true;
         } else {
             // other models, we only need the total number of tokens
             image_tokens->nx = n_tokens;
@@ -396,6 +398,13 @@ std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
     return image_tokens->id;
 }
 
+llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
+    if (image_tokens->use_mrope_pos) {
+        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+    }
+    return image_tokens->n_tokens();
+}
+
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
@@ -441,20 +450,35 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             n_tokens += chunk.tokens_text.size();
         } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            n_tokens += chunk.tokens_image->n_tokens();
+            n_tokens += mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
         } else {
             GGML_ASSERT(false && "chunk type not supported");
         }
     }
     return n_tokens;
 }
 
+llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks) {
+    llama_pos n_pos = 0;
+    for (auto & chunk : chunks) {
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            n_pos += chunk.tokens_text.size();
+        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            n_pos += mtmd_image_tokens_get_n_pos(chunk.tokens_image.get());
+        } else {
+            GGML_ASSERT(false && "chunk type not supported");
+        }
+    }
+    return n_pos;
+}
+
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
 struct decode_embd_batch {
     int n_pos_per_embd;
     int n_mmproj_embd;
     std::vector<llama_pos>      pos;
+    std::vector<llama_pos>      pos_view; // used by mrope
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
@@ -489,16 +513,46 @@ struct decode_embd_batch {
     }
 
     void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
+        GGML_ASSERT(n_pos_per_embd == 4);
         seq_id_0[0] = seq_id;
-        GGML_ABORT("TODO");
+        for (int y = 0; y < ny; y++) {
+            for (int x = 0; x < nx; x++) {
+                int i = y * nx + x;
+                pos[i                     ] = pos_0;
+                pos[i + batch.n_tokens    ] = pos_0 + y;
+                pos[i + batch.n_tokens * 2] = pos_0 + x;
+                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+            }
+        }
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
     }
 
     llama_batch get_view(int offset, int n_tokens) {
+        llama_pos * pos_ptr;
+        pos_view.clear();
+        pos_view.resize(n_tokens * n_pos_per_embd);
+        if (n_pos_per_embd > 1) {
+            // mrope
+            // for example, with layout of src: 1234...1234...1234...1234...
+            // offset 2 will give us dst: 34...34...34...34...
+            for (int i = 0; i < n_pos_per_embd; i++) {
+                auto src = pos.begin() + i * batch.n_tokens + offset;
+                pos_view.insert(pos_view.end(), src, src + n_tokens);
+            }
+            pos_ptr = pos_view.data();
+        } else {
+            // normal
+            pos_ptr = pos.data() + offset;
+        }
         return {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
            /*embd           =*/ batch.embd     + offset * n_mmproj_embd,
-            /*pos            =*/ batch.pos      + offset * n_pos_per_embd,
+            /*pos            =*/ pos_ptr,
             /*n_seq_id       =*/ batch.n_seq_id + offset,
             /*seq_id         =*/ batch.seq_id   + offset,
             /*logits         =*/ batch.logits   + offset,
@@ -566,12 +620,13 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             float * embd = mtmd_get_output_embd(ctx);
             decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
 
+            const int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get());
+            const int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get());
+
             if (mtmd_decode_use_mrope(ctx)) {
-                int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get());
-                int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get());
-                batch_embd.set_position_mrope(pos0, nx, ny, seq_id);
+                batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
             } else {
-                batch_embd.set_position_normal(pos0, seq_id);
+                batch_embd.set_position_normal(n_past, seq_id);
             }
 
             if (mtmd_decode_use_non_causal(ctx)) {
@@ -584,7 +639,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
                 llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
 
-                printf("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+                LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
 
                 int64_t t1 = ggml_time_ms();
                 ret = llama_decode(lctx, batch_embd_view);
@@ -600,9 +655,11 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 }
 
                 i_batch++;
-                n_past += n_tokens_batch;
             }
 
+            // for mrope, one image is one single **temporal** position
+            n_past += mtmd_decode_use_mrope(ctx) ? 1 : n_tokens;
+
             if (mtmd_decode_use_non_causal(ctx)) {
                 llama_set_causal_attn(lctx, true);
             }
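The new set_position_mrope above stores positions as four consecutive planes of batch.n_tokens entries each (temporal, height, width, unused), which is why get_view has to re-gather one slice per plane into pos_view instead of taking a single contiguous offset. Below is a self-contained sketch of that layout with illustrative sizes (nx = 3, ny = 2, pos_0 = 5); the numbers are worked out for this example and are not part of the commit:

#include <cstdio>
#include <vector>

int main() {
    const int nx = 3, ny = 2, n_tokens = nx * ny; // hypothetical 3x2 image
    const int n_pos_per_embd = 4;                 // M-RoPE uses 4 position dims
    const int pos_0 = 5;                          // hypothetical starting position

    // same fill pattern as decode_embd_batch::set_position_mrope
    std::vector<int> pos(n_tokens * n_pos_per_embd, 0);
    for (int y = 0; y < ny; y++) {
        for (int x = 0; x < nx; x++) {
            int i = y * nx + x;
            pos[i               ] = pos_0;     // temporal: whole image = one step
            pos[i + n_tokens    ] = pos_0 + y; // height
            pos[i + n_tokens * 2] = pos_0 + x; // width
            pos[i + n_tokens * 3] = 0;         // unused
        }
    }

    // same re-gather as get_view(offset, n_view): copy one slice per plane
    const int offset = 2, n_view = 2;
    std::vector<int> pos_view;
    for (int p = 0; p < n_pos_per_embd; p++) {
        pos_view.insert(pos_view.end(),
                        pos.begin() + p * n_tokens + offset,
                        pos.begin() + p * n_tokens + offset + n_view);
    }
    for (int v : pos_view) printf("%d ", v); // prints: 5 5 5 6 7 5 0 0
    printf("\n");
    return 0;
}

This layout also matches the n_past bookkeeping at the end of mtmd_helper_eval: after all image batches are decoded, n_past advances by 1 for M-RoPE (one temporal step for the whole image) and by n_tokens otherwise.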

examples/llava/mtmd.h

Lines changed: 5 additions & 1 deletion
@@ -102,6 +102,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * im
 MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
 MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
 MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens);
+MTMD_API llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens); // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
 MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
 
 // returns 0 on success
@@ -123,9 +124,12 @@ MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
 // helper functions (can be implemented based on other functions)
 //
 
-// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
+// helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
 MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
 
+// helper to count the total position of tokens from a list of chunks, useful to keep track of n_past
+MTMD_API llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks);
+
 // helper function that automatically:
 // 1. run llama_decode() on text chunks
 // 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
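A small usage sketch of the two helpers declared above, assuming mtmd.h and <cstdio> are included; the wrapper function itself is illustrative and not part of the commit:

// illustrative helper: report both counts for a tokenized prompt
static void print_prompt_accounting(mtmd_input_chunks & chunks) {
    size_t    n_kv  = mtmd_helper_get_n_tokens(chunks); // KV-cache cells the prompt will occupy
    llama_pos n_pos = mtmd_helper_get_n_pos(chunks);    // amount to add to n_past after eval
    printf("prompt uses %zu KV cells across %d positions\n", n_kv, (int) n_pos);
}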

examples/llava/tests.sh

Lines changed: 2 additions & 2 deletions
@@ -54,8 +54,8 @@ add_test "llama-mtmd-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted
 add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
 add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
-add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
-add_test "llama-qwen2vl-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
 add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
