Commit 8742f8a

mtmd : add qwen2vl and qwen2.5vl
1 parent 6f7a55c commit 8742f8a

File tree

2 files changed: +36 -7 lines changed


examples/llava/mtmd.cpp

Lines changed: 33 additions & 7 deletions
@@ -204,6 +204,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }

+    else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
+        // <|vision_start|> ... (image embeddings) ... <|vision_end|>
+        marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    }
+
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
     // for glm-edge, we don't need to add because the tokens are already in the returned embeddings
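Note on the hunk above: for the Qwen2-VL and Qwen2.5-VL projectors, the image marker in the prompt is wrapped in the model's <|vision_start|> / <|vision_end|> tokens before tokenization, so the image embeddings later land between those two special tokens. Below is a minimal standalone sketch of that rewrite. The marker value "<__image__>" is an assumption for illustration, and the local string_replace_all is only a stand-in for the helper of the same name used in mtmd.cpp.

// Standalone sketch: how the Qwen2-VL prompt rewrite behaves.
// "<__image__>" is an assumed marker value; string_replace_all below is a
// local stand-in for the mtmd.cpp helper of the same name.
#include <iostream>
#include <string>

static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) return;
    size_t pos = 0;
    while ((pos = s.find(search, pos)) != std::string::npos) {
        s.replace(pos, search.size(), replace);
        pos += replace.size();
    }
}

int main() {
    const std::string image_marker = "<__image__>";   // assumed marker value
    std::string prompt = "describe this image: <__image__>";

    const std::string marker_modified = "<|vision_start|>" + image_marker + "<|vision_end|>";
    string_replace_all(prompt, image_marker, marker_modified);

    std::cout << prompt << "\n";
    // prints: describe this image: <|vision_start|><__image__><|vision_end|>
}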

@@ -445,14 +452,16 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
 struct decode_embd_batch {
+    int n_pos_per_embd;
+    int n_mmproj_embd;
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
+    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
+        pos     .resize(n_tokens * n_pos_per_embd);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
         logits  .resize(n_tokens);
@@ -475,6 +484,18 @@ struct decode_embd_batch {
             batch.logits  [i] = false;
         }
     }
+
+    llama_batch get_view(int offset, int n_tokens) {
+        return {
+            /*n_tokens =*/ n_tokens,
+            /*tokens   =*/ nullptr,
+            /*embd     =*/ batch.embd     + offset * n_mmproj_embd,
+            /*pos      =*/ batch.pos      + offset * n_pos_per_embd,
+            /*n_seq_id =*/ batch.n_seq_id + offset,
+            /*seq_id   =*/ batch.seq_id   + offset,
+            /*logits   =*/ batch.logits   + offset,
+        };
+    }
 };

 int32_t mtmd_helper_eval(mtmd_context * ctx,
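Note on get_view() above: it returns a non-owning llama_batch that simply points into the parent batch's buffers at a token offset, so mtmd_helper_eval (see the later hunk) can decode sub-batches without building a new decode_embd_batch or copying embeddings per batch. A minimal sketch of that aliasing idea follows; sizes are illustrative and plain ints stand in for the llama.cpp types so the snippet is self-contained.

// Sketch of the view idea: per-token strides are n_mmproj_embd floats for embd
// and n_pos_per_embd entries for pos, matching get_view()'s pointer arithmetic.
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens       = 8;  // total image embedding tokens (illustrative)
    const int n_mmproj_embd  = 4;  // floats per token from the projector (illustrative)
    const int n_pos_per_embd = 4;  // 4 position values per token when M-RoPE is used

    std::vector<float> embd(n_tokens * n_mmproj_embd, 0.0f);
    std::vector<int>   pos (n_tokens * n_pos_per_embd, 0);

    const int offset         = 3;  // first token of the sub-batch
    const int n_tokens_batch = 2;

    // the "view": pointer arithmetic only, nothing is copied
    float * embd_view = embd.data() + offset * n_mmproj_embd;
    int   * pos_view  = pos.data()  + offset * n_pos_per_embd;

    printf("view covers tokens [%d, %d); embd_view starts at float #%ld\n",
           offset, offset + n_tokens_batch, (long)(embd_view - embd.data()));
    (void) pos_view;
}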
@@ -487,6 +508,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

     for (auto & chunk : chunks) {
         bool is_last = &chunk == &chunks.back();
@@ -534,22 +556,22 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             int32_t i_batch = 0;
             int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
             float * embd = mtmd_get_output_embd(ctx);
+            decode_embd_batch batch_embd(embd, n_tokens, n_past, seq_id, n_pos_per_embd, n_mmproj_embd);

             if (mtmd_decode_use_non_causal(ctx)) {
                 llama_set_causal_attn(lctx, false);
                 // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
             }

             while (i_batch < n_img_batches) { // split into batches
-                int32_t pos_offset = i_batch*n_batch;
-                int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-                float * embd_batch = embd + pos_offset*n_mmproj_embd;
-                decode_embd_batch batch_img(embd_batch, n_tokens_batch, n_past, 0);
+                int pos_offset = i_batch*n_batch;
+                int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+                llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);

                 printf("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);

                 int64_t t1 = ggml_time_ms();
-                ret = llama_decode(lctx, batch_img.batch);
+                ret = llama_decode(lctx, batch_embd_view);
                 if (ret != 0) {
                     LOG_ERR("failed to decode image\n");
                     llama_set_causal_attn(lctx, true); // restore causal attn
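Note on the decode loop above: the image embeddings are split into n_batch-sized sub-batches, and GGML_PAD rounds n_tokens up to a multiple of n_batch so the division yields a ceiling count of sub-batches; each iteration then decodes a view rather than a copy. A small standalone sketch of that splitting arithmetic follows; pad_to is a local stand-in for GGML_PAD's round-up and the sizes are illustrative.

// Sketch of the batch-splitting arithmetic used in mtmd_helper_eval.
#include <algorithm>
#include <cstdio>

// local stand-in for GGML_PAD(x, n): round x up to a multiple of n
static int pad_to(int x, int n) { return (x + n - 1) / n * n; }

int main() {
    const int n_tokens = 2300;   // image embedding tokens (illustrative)
    const int n_batch  = 1024;   // logical batch size (illustrative)

    const int n_img_batches = pad_to(n_tokens, n_batch) / n_batch;  // ceil(2300/1024) = 3

    for (int i_batch = 0; i_batch < n_img_batches; i_batch++) {
        int pos_offset     = i_batch * n_batch;
        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
        printf("batch %d/%d: tokens [%d, %d)\n",
               i_batch + 1, n_img_batches, pos_offset, pos_offset + n_tokens_batch);
    }
}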
@@ -612,6 +634,10 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
     return false;
 }

+bool mtmd_decode_use_mrope(mtmd_context * ctx) {
+    return ctx->use_mrope;
+}
+
 void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
     mtmd_image_tokens_free(val);
 }

examples/llava/mtmd.h

Lines changed: 3 additions & 0 deletions
@@ -114,6 +114,9 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 // whether we need to set non-causal mask before llama_decode
 MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);

+// whether the current model use M-RoPE for llama_decode
+MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
+


 //
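Note on the new header entry: Qwen2-VL's M-RoPE tracks more than one position value per token, and this commit reserves four position slots per embedding when mtmd_decode_use_mrope() reports true (hence the ? 4 : 1 in mtmd_helper_eval). A caller-side sketch follows; mtmd_decode_use_mrope() is the real API declared above, but it is stubbed locally here only so the snippet compiles on its own, since a real mtmd_context comes from loading a projector model.

// Caller-side sketch. The stub merely stands in for the real mtmd API so this
// example is self-contained.
#include <cstdio>
#include <vector>

struct mtmd_context;                                   // opaque, as in mtmd.h

static bool mtmd_decode_use_mrope(mtmd_context * /*ctx*/) {
    return true;                                       // pretend a Qwen2-VL projector was loaded
}

int main() {
    mtmd_context * ctx = nullptr;                      // placeholder; a real caller gets this from mtmd
    const int n_tokens = 256;                          // image embedding tokens (illustrative)

    // reserve per-token positions the same way mtmd_helper_eval does
    const int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
    std::vector<int> pos(n_tokens * n_pos_per_embd, 0);

    printf("reserved %zu position entries\n", pos.size());
}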
