@@ -128,6 +128,7 @@ struct mtmd_image_tokens_data {
 struct mtmd_image_tokens {
     uint32_t nx; // number of tokens in x direction
     uint32_t ny; // number of tokens in y direction
+    bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
     uint32_t n_tokens() const { return nx * ny; }
     clip_image_f32_batch batch_f32; // preprocessed image patches
     std::string id; // optional user-defined ID, useful for KV cache tracking
@@ -342,6 +343,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             // for Qwen2VL, we need this information for M-RoPE decoding positions
             image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_clip, batch_f32.entries[0].get());
             image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_clip, batch_f32.entries[0].get());
+            image_tokens->use_mrope_pos = true;
         } else {
             // other models, we only need the total number of tokens
             image_tokens->nx = n_tokens;
@@ -396,6 +398,13 @@ std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
     return image_tokens->id;
 }
 
+llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
+    if (image_tokens->use_mrope_pos) {
+        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+    }
+    return image_tokens->n_tokens();
+}
+
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
@@ -441,20 +450,35 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             n_tokens += chunk.tokens_text.size();
         } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            n_tokens += chunk.tokens_image->n_tokens();
+            n_tokens += mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
         } else {
             GGML_ASSERT(false && "chunk type not supported");
         }
     }
     return n_tokens;
 }
 
+llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks) {
+    llama_pos n_pos = 0;
+    for (auto & chunk : chunks) {
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            n_pos += chunk.tokens_text.size();
+        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            n_pos += mtmd_image_tokens_get_n_pos(chunk.tokens_image.get());
+        } else {
+            GGML_ASSERT(false && "chunk type not supported");
+        }
+    }
+    return n_pos;
+}
+
 // helper struct to make working with embd batch easier
 // note: this will be removed after llama_batch_ext refactoring
 struct decode_embd_batch {
     int n_pos_per_embd;
     int n_mmproj_embd;
     std::vector<llama_pos> pos;
+    std::vector<llama_pos> pos_view; // used by mrope
     std::vector<int32_t> n_seq_id;
     std::vector<llama_seq_id> seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
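
Taken together, the two helpers above separate two different counts: mtmd_helper_get_n_tokens reports how many embedding slots a prompt occupies, while mtmd_helper_get_n_pos reports how many temporal positions it consumes. For M-RoPE models such as Qwen2VL the two diverge on image chunks. A standalone sketch with made-up sizes (illustrative only, not code from this PR) showing the difference:

    #include <cstdio>

    int main() {
        // a prompt of 5 text tokens, one 32x24-patch image, then 3 text tokens
        const int n_text_before = 5, nx = 32, ny = 24, n_text_after = 3;
        const bool use_mrope = true; // e.g. Qwen2VL

        // what mtmd_helper_get_n_tokens() counts: embedding slots to decode
        const int n_tokens = n_text_before + nx * ny + n_text_after; // 776

        // what mtmd_helper_get_n_pos() counts: temporal positions consumed;
        // under M-RoPE the whole image contributes just 1
        const int n_pos = n_text_before + (use_mrope ? 1 : nx * ny) + n_text_after; // 9

        printf("n_tokens = %d, n_pos = %d\n", n_tokens, n_pos);
        return 0;
    }
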
@@ -489,16 +513,46 @@ struct decode_embd_batch {
     }
 
     void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
+        GGML_ASSERT(n_pos_per_embd == 4);
         seq_id_0[0] = seq_id;
-        GGML_ABORT("TODO");
+        for (int y = 0; y < ny; y++) {
+            for (int x = 0; x < nx; x++) {
+                int i = y * nx + x;
+                pos[i                     ] = pos_0;
+                pos[i + batch.n_tokens    ] = pos_0 + y;
+                pos[i + batch.n_tokens * 2] = pos_0 + x;
+                pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
+            }
+        }
+        for (int i = 0; i < batch.n_tokens; i++) {
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
     }
495533
496534 llama_batch get_view (int offset, int n_tokens) {
535+ llama_pos * pos_ptr;
536+ pos_view.clear ();
537+ pos_view.resize (n_tokens * n_pos_per_embd);
538+ if (n_pos_per_embd > 1 ) {
539+ // mrope
540+ // for example, with layout of src: 1234...1234...1234...1234...
541+ // offset 2 will give us dst: 34...34...34...34...
542+ for (int i = 0 ; i < n_pos_per_embd; i++) {
543+ auto src = pos.begin () + i * batch.n_tokens + offset;
544+ pos_view.insert (pos_view.end (), src, src + n_tokens);
545+ }
546+ pos_ptr = pos_view.data ();
547+ } else {
548+ // normal
549+ pos_ptr = pos.data () + offset;
550+ }
497551 return {
498552 /* n_tokens =*/ n_tokens,
499553 /* tokens =*/ nullptr ,
500554 /* embd =*/ batch.embd + offset * n_mmproj_embd,
501- /* pos =*/ batch. pos + offset * n_pos_per_embd ,
555+ /* pos =*/ pos_ptr ,
502556 /* n_seq_id =*/ batch.n_seq_id + offset,
503557 /* seq_id =*/ batch.seq_id + offset,
504558 /* logits =*/ batch.logits + offset,
@@ -566,12 +620,13 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             float * embd = mtmd_get_output_embd(ctx);
             decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
 
+            const int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get());
+            const int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get());
+
             if (mtmd_decode_use_mrope(ctx)) {
-                int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get());
-                int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get());
-                batch_embd.set_position_mrope(pos0, nx, ny, seq_id);
+                batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
             } else {
-                batch_embd.set_position_normal(pos0, seq_id);
+                batch_embd.set_position_normal(n_past, seq_id);
             }
 
             if (mtmd_decode_use_non_causal(ctx)) {
@@ -584,7 +639,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
                 llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
 
-                printf("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+                LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
 
                 int64_t t1 = ggml_time_ms();
                 ret = llama_decode(lctx, batch_embd_view);
@@ -600,9 +655,11 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 }
 
                 i_batch++;
-                n_past += n_tokens_batch;
             }
 
+            // for mrope, one image is one single **temporal** position
+            n_past += mtmd_decode_use_mrope(ctx) ? 1 : n_tokens;
+
             if (mtmd_decode_use_non_causal(ctx)) {
                 llama_set_causal_attn(lctx, true);
             }
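
Note the bookkeeping change above: n_past is no longer advanced inside the per-batch decode loop, because with M-RoPE the whole image occupies a single temporal position no matter how many llama_decode calls it is split across; the per-view positions still come from the pre-filled planes via get_view. A standalone sketch with illustrative numbers (not code from this PR):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_tokens = 96, n_batch = 32; // image slots, decode batch size
        const bool use_mrope = true;
        int n_past = 42; // temporal position after the preceding text

        // the image may need several decode calls...
        for (int off = 0; off < n_tokens; off += n_batch) {
            const int n_tokens_batch = std::min(n_batch, n_tokens - off);
            printf("decode view: offset = %d, n = %d\n", off, n_tokens_batch);
            // n_past is no longer advanced per batch here
        }

        // ...but n_past advances once, for the image as a whole
        n_past += use_mrope ? 1 : n_tokens;
        printf("n_past = %d\n", n_past); // 43 with M-RoPE, 138 otherwise
        return 0;
    }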