
Commit 227e139

mtmd: Expose helper_decode_image, output_embd_copy, image_tokens_copy/free
1 parent 814f795 commit 227e139

File tree

tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h

2 files changed: +124 −63 lines


tools/mtmd/mtmd.cpp

Lines changed: 99 additions & 58 deletions
@@ -167,7 +167,7 @@ struct mtmd_image_tokens {
     clip_image_f32_batch batch_f32; // preprocessed image patches
     std::string id; // optional user-defined ID, useful for KV cache tracking

-    mtmd_image_tokens clone() {
+    mtmd_image_tokens clone() const {
         return mtmd_image_tokens{
             nx,
             ny,
@@ -409,12 +409,6 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
     return 0;
 }

-static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
-    if (image_tokens) {
-        delete image_tokens;
-    }
-}
-
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
@@ -454,6 +448,23 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }

+float * mtmd_get_output_embd_copy(mtmd_context * ctx, size_t * n_embd_out) {
+    if (ctx->image_embd_v.empty()) {
+        *n_embd_out = 0;
+        return NULL;
+    }
+
+    *n_embd_out = ctx->image_embd_v.size();
+    float * copy = (float *) malloc(*n_embd_out * sizeof(float));
+    if (copy == NULL) {
+        *n_embd_out = 0;
+        return NULL;
+    }
+
+    memcpy(copy, ctx->image_embd_v.data(), ctx->image_embd_v.size() * sizeof(float));
+    return copy;
+}
+
 size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
     size_t n_tokens = 0;
     for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
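Caller-side note (not part of this commit): unlike mtmd_get_output_embd(), which returns a pointer into the context's internal buffer, the new accessor returns a heap copy allocated with malloc() and owned by the caller. A minimal sketch, assuming a successful mtmd_encode() has already run:

    // Hypothetical caller code; `ctx` is an initialized mtmd_context.
    size_t n_embd = 0;
    float * embd_copy = mtmd_get_output_embd_copy(ctx, &n_embd);
    if (embd_copy != NULL) {
        // embd_copy[0 .. n_embd-1] stays valid across later encode calls
        // ... use the embeddings ...
        free(embd_copy); // the caller owns the buffer
    }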
@@ -580,6 +591,69 @@ struct decode_embd_batch {
     }
 };

+// Helper function for decoding an image whose embeddings have already been calculated
+int32_t mtmd_helper_decode_image(
+        mtmd_context * ctx,
+        struct llama_context * lctx,
+        const mtmd_image_tokens * image_tokens,
+        float * embd,
+        llama_pos n_past,
+        llama_seq_id seq_id,
+        int32_t n_batch,
+        llama_pos * new_n_past) {
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
+
+    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+    int32_t i_batch = 0;
+    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+    decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+
+    const int nx = mtmd_image_tokens_get_nx(image_tokens);
+    const int ny = mtmd_image_tokens_get_ny(image_tokens);
+
+    if (mtmd_decode_use_mrope(ctx)) {
+        batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
+    } else {
+        batch_embd.set_position_normal(n_past, seq_id);
+    }
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, false);
+        // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
+    }
+
+    while (i_batch < n_img_batches) { // split into batches
+        int pos_offset = i_batch*n_batch;
+        int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+        llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
+
+        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+        int64_t t1 = ggml_time_ms();
+        int32_t ret = llama_decode(lctx, batch_embd_view);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_set_causal_attn(lctx, true); // restore causal attn
+            return ret;
+        }
+
+        if (ctx->print_timings) {
+            LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+        }
+
+        i_batch++;
+    }
+
+    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
+    *new_n_past = n_past;
+
+    if (mtmd_decode_use_non_causal(ctx)) {
+        llama_set_causal_attn(lctx, true);
+    }
+    return 0;
+}
+
 int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         struct llama_context * lctx,
         const mtmd_input_chunk * chunk,
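A quick worked example of the batch split above (editorial note, not in the commit): GGML_PAD(x, n) rounds x up to the next multiple of n, so GGML_PAD(n_tokens, n_batch) / n_batch is a ceiling division.

    // e.g. n_tokens = 1030, n_batch = 512:
    // GGML_PAD(1030, 512) / 512 = 1536 / 512 = 3 image batches (512 + 512 + 6 tokens)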
@@ -591,8 +665,6 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
     int32_t ret;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
     auto chunk_type = mtmd_input_chunk_get_type(chunk);
-    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
-    int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
         size_t n_tokens;
@@ -637,57 +709,13 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         if (ctx->print_timings) {
             LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
         }
-
-        int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
-        int32_t i_batch = 0;
-        int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
         float * embd = mtmd_get_output_embd(ctx);
-        decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
-
-        const int nx = mtmd_image_tokens_get_nx(image_tokens);
-        const int ny = mtmd_image_tokens_get_ny(image_tokens);
-
-        if (mtmd_decode_use_mrope(ctx)) {
-            batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
-        } else {
-            batch_embd.set_position_normal(n_past, seq_id);
-        }
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, false);
-            // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
-        }
-
-        while (i_batch < n_img_batches) { // split into batches
-            int pos_offset = i_batch*n_batch;
-            int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
-            llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
-
-            LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
-
-            int64_t t1 = ggml_time_ms();
-            ret = llama_decode(lctx, batch_embd_view);
-            if (ret != 0) {
-                LOG_ERR("failed to decode image\n");
-                llama_set_causal_attn(lctx, true); // restore causal attn
-                llama_batch_free(text_batch);
-                return ret;
-            }
-
-            if (ctx->print_timings) {
-                LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
-            }
-
-            i_batch++;
-        }
-
-        n_past += mtmd_image_tokens_get_n_pos(image_tokens);
-        *new_n_past = n_past;
-
-        if (mtmd_decode_use_non_causal(ctx)) {
-            llama_set_causal_attn(lctx, true);
+        ret = mtmd_helper_decode_image(ctx, lctx, image_tokens, embd, n_past, seq_id, n_batch, new_n_past);
+        if (ret != 0) {
+            LOG_ERR("failed to decode image\n");
+            llama_batch_free(text_batch);
+            return ret;
         }
-
     } else {
         GGML_ABORT("chunk type not supported");
     }
@@ -903,6 +931,19 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
     return image_tokens->n_tokens();
 }

+void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
+    if (image_tokens) {
+        delete image_tokens;
+    }
+}
+
+mtmd_image_tokens * mtmd_image_tokens_copy(const mtmd_image_tokens * image_tokens) {
+    if (!image_tokens) {
+        return nullptr;
+    }
+    return new mtmd_image_tokens(image_tokens->clone());
+}
+
 // test function

 mtmd_input_chunks * mtmd_test_create_input_chunks() {
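With the copy/free pair now public, a caller can keep image tokens alive independently of the chunk that produced them. A sketch (hypothetical caller code), assuming the tokens come from the existing chunk accessor mtmd_input_chunk_get_tokens_image():

    // Retain a deep copy of the tokens from a chunk.
    const mtmd_image_tokens * tokens = mtmd_input_chunk_get_tokens_image(chunk);
    mtmd_image_tokens * owned = mtmd_image_tokens_copy(tokens); // deep copy via clone()
    // ... the chunk (and `tokens`) may be freed here; `owned` remains valid ...
    mtmd_image_tokens_free(owned); // NULL-safe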

tools/mtmd/mtmd.h

Lines changed: 25 additions & 5 deletions
@@ -143,12 +143,14 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
 //
 // the instance will be constructed via mtmd_tokenize()
 // it will be freed along with mtmd_input_chunk
-MTMD_API size_t       mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
-MTMD_API size_t       mtmd_image_tokens_get_nx      (const mtmd_image_tokens * image_tokens);
-MTMD_API size_t       mtmd_image_tokens_get_ny      (const mtmd_image_tokens * image_tokens);
-MTMD_API const char * mtmd_image_tokens_get_id      (const mtmd_image_tokens * image_tokens);
+MTMD_API size_t              mtmd_image_tokens_get_n_tokens (const mtmd_image_tokens * image_tokens);
+MTMD_API size_t              mtmd_image_tokens_get_nx       (const mtmd_image_tokens * image_tokens);
+MTMD_API size_t              mtmd_image_tokens_get_ny       (const mtmd_image_tokens * image_tokens);
+MTMD_API const char *        mtmd_image_tokens_get_id       (const mtmd_image_tokens * image_tokens);
 // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
-MTMD_API llama_pos    mtmd_image_tokens_get_n_pos   (const mtmd_image_tokens * image_tokens);
+MTMD_API llama_pos           mtmd_image_tokens_get_n_pos    (const mtmd_image_tokens * image_tokens);
+MTMD_API mtmd_image_tokens * mtmd_image_tokens_copy         (const mtmd_image_tokens * image_tokens);
+MTMD_API void                mtmd_image_tokens_free         (mtmd_image_tokens * image_tokens);

 // tokenize an input text prompt and an image
 // the prompt must have the input image marker (default: "<__image__>") in it
@@ -178,6 +180,9 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
 // get output embeddings from the last encode pass
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);

+// returns a copy of output embeddings from the last encode pass, of size n_embd_out
+MTMD_API float * mtmd_get_output_embd_copy(mtmd_context * ctx, size_t * n_embd_out);
+
 /////////////////////////////////////////

 //
@@ -231,6 +236,16 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
                                                bool logits_last,
                                                llama_pos * new_n_past);

+// helper function to decode an image whose embeddings have already been calculated
+MTMD_API int32_t mtmd_helper_decode_image(mtmd_context * ctx,
+                                          struct llama_context * lctx,
+                                          const mtmd_image_tokens * image_tokens,
+                                          float * embd,
+                                          llama_pos n_past,
+                                          llama_seq_id seq_id,
+                                          int32_t n_batch,
+                                          llama_pos * new_n_past);
+
 /////////////////////////////////////////

 // test function, to be used in test-mtmd-c-api.c
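Taken together, the new entry points let a caller encode an image once and decode its embeddings later, e.g. from a cache. A rough sketch (hypothetical caller code), assuming image_tokens and embd were retained earlier via mtmd_image_tokens_copy() and mtmd_get_output_embd_copy():

    // Decode previously computed image embeddings into the llama context.
    llama_pos new_n_past = 0;
    int32_t ret = mtmd_helper_decode_image(
        ctx, lctx,
        image_tokens, // retained copy of the image tokens
        embd,         // retained copy of the embeddings
        n_past,       // current position in the sequence
        seq_id,       // target sequence id
        n_batch,      // max tokens per llama_decode() call
        &new_n_past); // set to the position after the image on success
    if (ret == 0) {
        n_past = new_n_past;
    }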
@@ -268,6 +283,11 @@ struct mtmd_input_chunk_deleter {
 };
 using input_chunk_ptr = std::unique_ptr<mtmd_input_chunk, mtmd_input_chunk_deleter>;

+struct mtmd_image_tokens_deleter {
+    void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); }
+};
+using image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
+
 struct bitmap {
     bitmap_ptr ptr;
     bitmap() : ptr(nullptr) {}
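For C++ callers, the new deleter makes the ownership shown earlier automatic; a small sketch, assuming the C++ helpers in mtmd.h live in namespace mtmd as the surrounding code suggests:

    // Hypothetical C++ caller code using the RAII wrapper.
    mtmd::image_tokens_ptr owned(mtmd_image_tokens_copy(tokens));
    // ... use owned.get() with the C API ...
    // mtmd_image_tokens_free() runs automatically when `owned` goes out of scope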
