@@ -204,6 +204,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
+    else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
+        // <|vision_start|> ... (image embeddings) ... <|vision_end|>
+        marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    }
+
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
     // for glm-edge, we don't need to add because the tokens are already in the returned embeddings
 
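For the Qwen2-VL family, the tokenizer wraps every image marker in the model's vision delimiter tokens before text tokenization, so the image embeddings end up between <|vision_start|> and <|vision_end|>. A minimal sketch of that replacement, with string_replace_all reimplemented here (the real helper lives in llama.cpp's common code) and mtmd's default <__image__> marker assumed:

    #include <iostream>
    #include <string>

    // stand-in for the string_replace_all helper used above
    static void string_replace_all(std::string & s, const std::string & from, const std::string & to) {
        if (from.empty()) return;
        size_t p = 0;
        while ((p = s.find(from, p)) != std::string::npos) {
            s.replace(p, from.size(), to);
            p += to.size(); // skip past the replacement so the embedded marker isn't matched again
        }
    }

    int main() {
        std::string prompt_modified = "describe this image: <__image__>";
        const std::string image_marker = "<__image__>"; // mtmd's default marker, assumed here
        // same transformation as the PROJECTOR_TYPE_QWEN2VL branch above
        std::string marker_modified = "<|vision_start|>" + image_marker + "<|vision_end|>";
        string_replace_all(prompt_modified, image_marker, marker_modified);
        std::cout << prompt_modified << "\n";
        // prints: describe this image: <|vision_start|><__image__><|vision_end|>
    }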
@@ -445,14 +452,16 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
 struct decode_embd_batch {
+    int n_pos_per_embd;
+    int n_mmproj_embd;
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
+    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
+        pos     .resize(n_tokens * n_pos_per_embd);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
         logits  .resize(n_tokens);
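pos now holds n_pos_per_embd values per token because M-RoPE decomposes a token's position into several sections rather than one scalar; for Qwen2-VL, llama.cpp passes 4 values per token (temporal, height, width, plus one unused section). The fill loop is outside this diff, but here is a sketch of what the M-RoPE positions for an ny x nx grid of image patches could look like, assuming the section-blocked layout used by llama.cpp's qwen2vl example (an assumption, not shown in this commit):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    typedef int32_t llama_pos;

    // sketch: M-RoPE positions for an ny * nx grid of image patches,
    // stored as four consecutive blocks of n_tokens values: [t | y | x | unused]
    static std::vector<llama_pos> mrope_image_positions(llama_pos pos_0, int ny, int nx) {
        const int n_tokens = ny * nx;
        std::vector<llama_pos> pos(n_tokens * 4);
        for (int y = 0; y < ny; y++) {
            for (int x = 0; x < nx; x++) {
                const int i = y * nx + x;
                pos[i]                = pos_0;     // temporal section: one "time" for the whole image
                pos[i + n_tokens]     = pos_0 + y; // height section
                pos[i + n_tokens * 2] = pos_0 + x; // width section
                pos[i + n_tokens * 3] = 0;         // fourth section unused for images
            }
        }
        return pos;
    }

    int main() {
        // 2x3 patch grid starting at position 10; token 4 is patch (y=1, x=1)
        auto pos = mrope_image_positions(10, 2, 3);
        printf("token 4: t=%d y=%d x=%d\n", pos[4], pos[4 + 6], pos[4 + 12]);
        // token 4: t=10 y=11 x=11
    }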
@@ -475,6 +484,18 @@ struct decode_embd_batch {
             batch.logits  [i] = false;
         }
     }
+
+    llama_batch get_view(int offset, int n_tokens) {
+        return {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ batch.embd     + offset * n_mmproj_embd,
+            /*pos            =*/ batch.pos      + offset * n_pos_per_embd,
+            /*n_seq_id       =*/ batch.n_seq_id + offset,
+            /*seq_id         =*/ batch.seq_id   + offset,
+            /*logits         =*/ batch.logits   + offset,
+        };
+    }
 };
 
 int32_t mtmd_helper_eval(mtmd_context * ctx,
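get_view returns a llama_batch whose pointers alias the parent's storage, offset to the requested token range, so slicing an oversized image into decode-sized chunks costs a few pointer additions instead of the per-chunk construction the old code did. The view is only valid while the parent decode_embd_batch is alive. A self-contained sketch of the same pointer-view idea, using hypothetical names and only the embd field for brevity:

    #include <cassert>
    #include <vector>

    // hypothetical mini version of the view pattern above
    struct embd_store {
        int n_mmproj_embd;
        std::vector<float> embd; // n_tokens * n_mmproj_embd floats, one allocation

        // a "view" is just the base pointer advanced to token `offset`; nothing is copied
        float * view(int offset) { return embd.data() + (size_t) offset * n_mmproj_embd; }
    };

    int main() {
        embd_store store { /*n_mmproj_embd=*/8 };
        store.embd.resize(100 * 8);

        float * chunk = store.view(/*offset=*/32);
        assert(chunk == store.embd.data() + 32 * 8); // aliases the parent storage
        (void) chunk;
    }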
@@ -487,6 +508,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
 
     for (auto & chunk : chunks) {
         bool is_last = &chunk == &chunks.back();
@@ -534,22 +556,22 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             int32_t i_batch = 0;
             int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
             float * embd = mtmd_get_output_embd(ctx);
+            decode_embd_batch batch_embd(embd, n_tokens, n_past, seq_id, n_pos_per_embd, n_mmproj_embd);
 
             if (mtmd_decode_use_non_causal(ctx)) {
                 llama_set_causal_attn(lctx, false);
                 // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
             }
 
             while (i_batch < n_img_batches) { // split into batches
-                int32_t pos_offset = i_batch*n_batch;
-                int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-                float * embd_batch = embd + pos_offset*n_mmproj_embd;
-                decode_embd_batch batch_img(embd_batch, n_tokens_batch, n_past, 0);
+                int pos_offset = i_batch*n_batch;
+                int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+                llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
 
                 printf("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
 
                 int64_t t1 = ggml_time_ms();
-                ret = llama_decode(lctx, batch_img.batch);
+                ret = llama_decode(lctx, batch_embd_view);
                 if (ret != 0) {
                     LOG_ERR("failed to decode image\n");
                     llama_set_causal_attn(lctx, true); // restore causal attn
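n_img_batches is computed as GGML_PAD(n_tokens, n_batch) / n_batch, which is ceiling division: round n_tokens up to the next multiple of n_batch, then divide. A quick check of the loop arithmetic (the macro body mirrors ggml.h):

    #include <algorithm>
    #include <cstdio>

    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n)) // as in ggml.h

    int main() {
        // e.g. a 1350-token image with n_batch = 512 takes 3 decode calls:
        // two full batches and a final partial one of 326 tokens
        int n_tokens = 1350, n_batch = 512;
        int n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; // ceil(1350/512) = 3
        for (int i_batch = 0; i_batch < n_img_batches; i_batch++) {
            int pos_offset     = i_batch * n_batch;
            int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
            printf("batch %d/%d: offset = %d, n_tokens_batch = %d\n",
                   i_batch + 1, n_img_batches, pos_offset, n_tokens_batch);
        }
    }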
@@ -612,6 +634,10 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
     return false;
 }
 
+bool mtmd_decode_use_mrope(mtmd_context * ctx) {
+    return ctx->use_mrope;
+}
+
 void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
     mtmd_image_tokens_free(val);
 }
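mtmd_decode_use_mrope gives callers outside mtmd.cpp a way to pick the 4-vs-1 position count without reaching into context internals, mirroring what mtmd_helper_eval does above. A hypothetical caller-side helper, assuming the declaration is exported via mtmd.h:

    #include "mtmd.h" // assumed to declare mtmd_decode_use_mrope()

    // hypothetical helper: number of llama_pos slots a batch of n_tokens
    // embeddings needs, following the same rule as mtmd_helper_eval
    static size_t pos_buffer_size(mtmd_context * ctx, int32_t n_tokens) {
        const int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
        return (size_t) n_tokens * n_pos_per_embd;
    }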