@@ -167,7 +167,7 @@ struct mtmd_image_tokens {
     clip_image_f32_batch batch_f32; // preprocessed image patches
     std::string id; // optional user-defined ID, useful for KV cache tracking
 
-    mtmd_image_tokens clone() {
+    mtmd_image_tokens clone() const {
         return mtmd_image_tokens{
             nx,
             ny,
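Note on the hunk above: marking clone() as const lets it be called through a const mtmd_image_tokens *, which the new mtmd_image_tokens_copy() added further down relies on. A minimal sketch of the call it enables (hypothetical caller, not part of this diff):

    // `tokens` is pointer-to-const, so only const member functions
    // may be called through it; without the qualifier this would not compile.
    mtmd_image_tokens deep_copy(const mtmd_image_tokens * tokens) {
        return tokens->clone();
    }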
@@ -409,12 +409,6 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
     return 0;
 }
 
-static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
-    if (image_tokens) {
-        delete image_tokens;
-    }
-}
-
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
@@ -454,6 +448,23 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }
 
+float * mtmd_get_output_embd_copy(mtmd_context * ctx, size_t * n_embd_out) {
+    if (ctx->image_embd_v.empty()) {
+        *n_embd_out = 0;
+        return NULL;
+    }
+
+    *n_embd_out = ctx->image_embd_v.size();
+    float * copy = (float *) malloc(*n_embd_out * sizeof(float));
+    if (copy == NULL) {
+        *n_embd_out = 0;
+        return NULL;
+    }
+
+    memcpy(copy, ctx->image_embd_v.data(), ctx->image_embd_v.size() * sizeof(float));
+    return copy;
+}
+
 size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
     size_t n_tokens = 0;
     for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
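Since mtmd_get_output_embd_copy() allocates with malloc(), ownership passes to the caller, who must release the buffer with free(). A usage sketch, assuming `ctx` is an mtmd_context on which mtmd_encode() has already run (hypothetical caller, not part of this diff):

    // Snapshot the output embeddings so a later mtmd_encode() call
    // cannot overwrite them behind our back.
    size_t n_embd = 0;
    float * embd_copy = mtmd_get_output_embd_copy(ctx, &n_embd);
    if (embd_copy != NULL) {
        // ... consume n_embd floats from embd_copy ...
        free(embd_copy); // caller owns the malloc'd buffer
    }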
@@ -580,6 +591,69 @@ struct decode_embd_batch {
     }
 };
 
+// Helper function for decoding an image whose embeddings have already been calculated
+int32_t mtmd_helper_decode_image(
+        mtmd_context * ctx,
+        struct llama_context * lctx,
+        const mtmd_image_tokens * image_tokens,
+        float * embd,
+        llama_pos n_past,
+        llama_seq_id seq_id,
+        int32_t n_batch,
+        llama_pos * new_n_past) {
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
+
+    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+    int32_t i_batch = 0;
+    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+    decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+
+    const int nx = mtmd_image_tokens_get_nx(image_tokens);
+    const int ny = mtmd_image_tokens_get_ny(image_tokens);
+
+    if (mtmd_decode_use_mrope(ctx)) {
+        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
+    } else {
+        batch_embd.set_position_normal(n_past, seq_id);
+    }
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, false);
+        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
+    }
+
+    while (i_batch < n_img_batches) { // split into batches
+        int pos_offset = i_batch*n_batch;
+        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
+
+        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+        int64_t t1 = ggml_time_ms();
+        int32_t ret = llama_decode(lctx, batch_embd_view);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_set_causal_attn(lctx, true); // restore causal attn
+            return ret;
+        }
+
+        if (ctx->print_timings) {
+            LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+        }
+
+        i_batch++;
+    }
+
+    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
+    *new_n_past = n_past;
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, true);
+    }
+    return 0;
+}
+
 int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         struct llama_context * lctx,
         const mtmd_input_chunk * chunk,
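mtmd_helper_decode_image() factors the batched image-decoding loop out of mtmd_helper_eval_chunk_single() (see the shrunken hunk below), so callers that obtain embeddings themselves can reuse it. A sketch of the intended call sequence, assuming `chunk` is an image chunk and that `ctx`, `lctx`, `n_past`, `seq_id`, and `n_batch` are already set up (hypothetical caller, not part of this diff):

    const mtmd_image_tokens * img = mtmd_input_chunk_get_tokens_image(chunk);
    if (mtmd_encode(ctx, img) != 0) {
        // handle encode failure
    }
    float * embd = mtmd_get_output_embd(ctx);
    llama_pos new_n_past = 0;
    if (mtmd_helper_decode_image(ctx, lctx, img, embd,
                                 n_past, seq_id, n_batch, &new_n_past) == 0) {
        n_past = new_n_past; // positions advance by the image's token count
    }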
@@ -591,8 +665,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
     int32_t ret;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
     auto chunk_type = mtmd_input_chunk_get_type(chunk);
-    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
-    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
 
     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
         size_t n_tokens;
@@ -637,57 +709,13 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         if (ctx->print_timings) {
             LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
         }
-
-        int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
-        int32_t i_batch = 0;
-        int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
         float * embd = mtmd_get_output_embd(ctx);
-        decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
-
-        const int nx = mtmd_image_tokens_get_nx(image_tokens);
-        const int ny = mtmd_image_tokens_get_ny(image_tokens);
-
-        if (mtmd_decode_use_mrope(ctx)) {
-            batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
-        } else {
-            batch_embd.set_position_normal(n_past, seq_id);
-        }
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, false);
-            // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
-        }
-
-        while (i_batch < n_img_batches) { // split into batches
-            int pos_offset = i_batch*n_batch;
-            int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-            llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
-
-            LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
-
-            int64_t t1 = ggml_time_ms();
-            ret = llama_decode(lctx, batch_embd_view);
-            if (ret != 0) {
-                LOG_ERR("failed to decode image\n");
-                llama_set_causal_attn(lctx, true); // restore causal attn
-                llama_batch_free(text_batch);
-                return ret;
-            }
-
-            if (ctx->print_timings) {
-                LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
-            }
-
-            i_batch++;
-        }
-
-        n_past += mtmd_image_tokens_get_n_pos(image_tokens);
-        *new_n_past = n_past;
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, true);
+        ret = mtmd_helper_decode_image(ctx, lctx, image_tokens, embd, n_past, seq_id, n_batch, new_n_past);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_batch_free(text_batch);
+            return ret;
         }
-
     } else {
         GGML_ABORT("chunk type not supported");
     }
@@ -903,6 +931,19 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
     return image_tokens->n_tokens();
 }
 
+void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
+    if (image_tokens) {
+        delete image_tokens;
+    }
+}
+
+mtmd_image_tokens * mtmd_image_tokens_copy(const mtmd_image_tokens * image_tokens) {
+    if (!image_tokens) {
+        return nullptr;
+    }
+    return new mtmd_image_tokens(image_tokens->clone());
+}
+
 // test function
 
 mtmd_input_chunks * mtmd_test_create_input_chunks() {
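With this hunk, mtmd_image_tokens_free() moves from a file-local static (removed above) to the public API, pairing with the new mtmd_image_tokens_copy() as an allocate/release pair for C callers. A minimal sketch (hypothetical caller, not part of this diff):

    // Keep image tokens alive beyond the lifetime of the chunk
    // they came from, e.g. for KV cache tracking via their id.
    mtmd_image_tokens * owned = mtmd_image_tokens_copy(img_tokens);
    // ... use `owned` ...
    mtmd_image_tokens_free(owned); // NULL-safe, like free()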