Skip to content

Commit a46b6db

Browse files
committed
mtmd : add more api around mtmd_image_tokens
1 parent b6930eb commit a46b6db

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

examples/llava/mtmd.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,36 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
166166
return output;
167167
}
168168

169-
void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
170-
for (auto & chunk : *chunks) {
171-
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
172-
delete chunk.tokens_image;
169+
void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
170+
if (image_tokens) {
171+
delete image_tokens;
172+
}
173+
}
174+
175+
void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images) {
176+
if (free_images) {
177+
for (auto & chunk : *chunks) {
178+
if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
179+
mtmd_image_tokens_free(chunk.tokens_image);
180+
chunk.tokens_image = nullptr;
181+
}
173182
}
174183
}
175184
delete chunks;
176185
}
177186

187+
size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
188+
return image_tokens->n_tokens();
189+
}
190+
191+
size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
192+
return image_tokens->nx;
193+
}
194+
195+
size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
196+
return image_tokens->ny;
197+
}
198+
178199
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
179200
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
180201
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
@@ -289,7 +310,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
289310
LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
290311
}
291312

292-
int32_t n_tokens = chunk.tokens_image->n_tokens();
313+
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image);
293314
float * embd = mtmd_get_output_embd(ctx);
294315
decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
295316
int64_t t1 = ggml_time_ms();
@@ -339,3 +360,11 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp
339360
std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
340361
return 0;
341362
}
363+
364+
bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
365+
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
366+
if (proj_type == PROJECTOR_TYPE_GEMMA3) {
367+
return true;
368+
}
369+
return false;
370+
}

examples/llava/mtmd.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,20 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
8181
// 2. (image tokens)
8282
// 3. "<end_of_image>\ndescribe it in detail."
8383
// number of bitmaps must be equal to the number of image markers in the prompt
84+
// the returned value must be freed using mtmd_input_chunks_free()
8485
// this function is thread-safe (shared ctx)
8586
MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
8687
const mtmd_input_text & text,
8788
const std::vector<mtmd_bitmap> & bitmaps);
8889

89-
// free image chunk data
90-
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
90+
// if free_images = true, free the image tokens; otherwise, you must free them using mtmd_image_tokens_free()
91+
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks, bool free_images);
92+
93+
// access mtmd_image_tokens
94+
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
95+
MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
96+
MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
97+
MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
9198

9299
// returns 0 on success
93100
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
@@ -96,6 +103,11 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
96103
// get output embeddings from the last encode pass
97104
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
98105

106+
// whether we need to set non-causal mask before llama_decode
107+
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
108+
109+
110+
99111
//
100112
// helper functions (can be implemented based on other functions)
101113
//
@@ -133,10 +145,15 @@ struct mtmd_context_deleter {
133145
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
134146

135147
struct mtmd_input_chunks_deleter {
136-
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
148+
void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val, true); }
137149
};
138150
using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
139151

152+
struct mtmd_image_tokens_deleter {
153+
void operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); }
154+
};
155+
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
156+
140157
#else
141158

142159
static_assert(false && "C header is not yet supported by this library");

0 commit comments

Comments
 (0)