@@ -44,18 +44,30 @@ struct mtmd_image_tokens_data {
     clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
 };
 
-mtmd_context_ptr mtmd_init_from_file(const char * mmproj_fname,
+struct mtmd_image_tokens {
+    uint32_t nx; // number of tokens in x direction
+    uint32_t ny; // number of tokens in y direction
+    uint32_t n_tokens() const { return nx * ny; }
+    clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
+};
+
+mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
         const struct llama_model * text_model,
         const struct mtmd_context_params ctx_params) {
     try {
-        auto ctx = std::make_shared<mtmd_context>(mmproj_fname, text_model, ctx_params);
-        return ctx;
+        return new mtmd_context(mmproj_fname, text_model, ctx_params);
     } catch (const std::exception & e) {
         LOG_ERR("%s: error: %s\n", __func__, e.what());
         return nullptr;
     }
 }
 
+void mtmd_free(mtmd_context * ctx) {
+    if (ctx) {
+        delete ctx;
+    }
+}
+
 int32_t mtmd_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) {
     clip_image_u8_ptr img_u8(clip_image_u8_init());
     bool ok = clip_image_load_from_file(fname, img_u8.get());
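Since `mtmd_init_from_file` now returns a raw `mtmd_context *` (nullptr on failure) instead of a `shared_ptr`, the new `mtmd_free` is the required cleanup path. C++ callers that still want RAII can wrap the pair themselves; a minimal sketch, assuming only the two functions above (the alias name `mtmd_context_handle`, the file path, and the `text_model`/`ctx_params` variables are illustrative, not part of the API):

    #include <memory>

    // Hypothetical convenience alias: a unique_ptr that calls mtmd_free()
    // instead of delete, so an early return cannot leak the context.
    using mtmd_context_handle = std::unique_ptr<mtmd_context, decltype(&mtmd_free)>;

    mtmd_context_handle ctx_handle(
        mtmd_init_from_file("mmproj.gguf", text_model, ctx_params),
        &mtmd_free);
    if (!ctx_handle) {
        // init failed; mtmd_init_from_file already logged the reason
    }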
@@ -89,10 +101,10 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
     return result;
 }
 
-int32_t mtmd_tokenize(mtmd_context_ptr & ctx,
-                      std::vector<mtmd_input_chunk> & output,
-                      const mtmd_input_text & text,
-                      const std::vector<mtmd_bitmap> & bitmaps) {
+mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
+                                  const mtmd_input_text & text,
+                                  const std::vector<mtmd_bitmap> & bitmaps) {
+    mtmd_input_chunks * output = new mtmd_input_chunks;
     auto vocab = llama_model_get_vocab(ctx->text_model);
 
     std::string prompt_modified(text.text);
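The signature change above swaps a caller-owned `std::vector` out-parameter for a library-owned opaque pointer, which keeps STL types out of the public surface. A hedged before/after sketch of the call site (variable setup assumed):

    // before: caller owned the allocation, errors were int codes
    //   std::vector<mtmd_input_chunk> chunks;
    //   int32_t rc = mtmd_tokenize(ctx, chunks, text, bitmaps);
    // after: the library owns the allocation, nullptr signals failure
    //   mtmd_input_chunks * chunks = mtmd_tokenize(ctx, text, bitmaps);
    //   ...
    //   mtmd_input_chunks_free(chunks);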
@@ -107,8 +119,8 @@ int32_t mtmd_tokenize(mtmd_context_ptr & ctx,
     }
 
     std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
-    output.clear();
-    output.reserve(parts.size());
+    output->clear();
+    output->reserve(parts.size());
 
     size_t i_img = 0;
 
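For reference, the `string_split_str` call above drives the interleaving that follows: the prompt is split on `ctx->image_marker`, and the loop in the next hunk emits one image chunk between each adjacent pair of text parts. A toy trace (marker value illustrative):

    // prompt: "describe <__image__> briefly"   (marker: "<__image__>")
    // parts  -> { "describe ", " briefly" }
    // output -> TEXT("describe "), IMAGE(bitmaps[0]), TEXT(" briefly")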
@@ -119,18 +131,19 @@ int32_t mtmd_tokenize(mtmd_context_ptr & ctx,
         if (tokens.empty()) {
             continue;
         }
-        output.push_back({
-            LLAVA2_INPUT_CHUNK_TYPE_TEXT,
+        mtmd_input_chunk chunk{
+            MTMD_INPUT_CHUNK_TYPE_TEXT,
             std::move(tokens),
             {},
-        });
+        };
+        output->emplace_back(std::move(chunk));
 
         if (&parts.back() != &part) {
             // add image token to middle of 2 parts
 
             if (i_img >= bitmaps.size()) {
                 LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
-                return 2;
+                return nullptr;
             }
 
             // shim layer
@@ -145,54 +158,58 @@ int32_t mtmd_tokenize(mtmd_context_ptr & ctx,
             bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), batch_f32.get());
             if (!ok) {
                 LOG_ERR("Unable to preprocess image\n");
-                return 1;
+                return nullptr;
             }
 
-            mtmd_image_tokens image_tokens;
-            image_tokens.nx = 0; // TODO
-            image_tokens.ny = 0; // TODO
-            image_tokens.n_tokens = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
-            image_tokens.data = std::unique_ptr<mtmd_image_tokens_data>(
-                new mtmd_image_tokens_data{
-                    std::move(batch_f32),
-                }
-            );
-
-            output.push_back({
-                LLAVA2_INPUT_CHUNK_TYPE_IMAGE,
+            mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
+            image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
+            image_tokens->ny = 1; // TODO
+            image_tokens->batch_f32 = std::move(batch_f32);
+
+            mtmd_input_chunk chunk{
+                MTMD_INPUT_CHUNK_TYPE_IMAGE,
                 {},
-                std::move(image_tokens),
-            });
+                image_tokens,
+            };
+            output->emplace_back(std::move(chunk));
             i_img++;
         }
     }
 
-    return 0;
+    return output;
+}
+
+void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
+    for (auto & chunk : *chunks) {
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
+            delete chunk.tokens_image;
+        }
+    }
+    delete chunks;
 }
 
-LLAVA2_API int32_t mtmd_encode(mtmd_context_ptr & ctx,
-                               const mtmd_image_tokens & image_tokens) {
+int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
-    ctx->image_embd_v.resize(image_tokens.n_tokens * n_mmproj_embd);
+    ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
     bool ok = clip_image_batch_encode(
         ctx->ctx_clip,
         ctx->n_threads,
-        image_tokens.data->batch_f32.get(),
+        image_tokens->batch_f32.get(),
         ctx->image_embd_v.data());
     return ok ? 0 : 1;
 }
 
-LLAVA2_API float * mtmd_get_output_embd(mtmd_context_ptr & ctx) {
+float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }
 
-size_t mtmd_helper_get_n_tokens(std::vector<mtmd_input_chunk> & chunks) {
+size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
     size_t n_tokens = 0;
-    for (auto & chunk : chunks) {
-        if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_TEXT) {
+    for (auto & chunk : *chunks) {
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             n_tokens += chunk.tokens_text.size();
-        } else if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_IMAGE) {
-            n_tokens += chunk.tokens_image.n_tokens;
+        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            n_tokens += chunk.tokens_image->n_tokens();
         } else {
             GGML_ASSERT(false && "chunk type not supported");
         }
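With `mtmd_tokenize` now returning heap-allocated chunks, ownership flows out of the library: a non-null result must eventually go through `mtmd_input_chunks_free`, which also deletes each chunk's owned `mtmd_image_tokens`. A hedged end-to-end sketch of the new contract, assuming `ctx` and `bitmaps` were prepared earlier (e.g. via `mtmd_bitmap_init_from_file`); the prompt string and marker value are illustrative:

    mtmd_input_text text;
    text.text = "what is in this image? <__image__>"; // marker value illustrative

    mtmd_input_chunks * chunks = mtmd_tokenize(ctx, text, bitmaps);
    if (chunks == nullptr) {
        return 1; // preprocessing failed or not enough bitmaps for the markers
    }

    size_t n_total = mtmd_helper_get_n_tokens(chunks); // text + image positions
    LOG_INF("prompt occupies %zu positions\n", n_total);
    // ... evaluate the chunks here ...

    mtmd_input_chunks_free(chunks); // frees the owned mtmd_image_tokens too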
@@ -235,19 +252,19 @@ struct decode_embd_batch {
     }
 };
 
-int32_t mtmd_helper_eval(mtmd_context_ptr & ctx,
+int32_t mtmd_helper_eval(mtmd_context * ctx,
         llama_context * lctx,
-        std::vector<mtmd_input_chunk> & chunks,
+        mtmd_input_chunks * chunks,
         llama_pos pos0,
         llama_seq_id seq_id,
         int32_t n_batch) {
     int32_t ret;
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
 
-    for (auto & chunk : chunks) {
-        bool is_last = &chunk == &chunks.back();
-        if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_TEXT) {
+    for (auto & chunk : *chunks) {
+        bool is_last = &chunk == &chunks->back();
+        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             // TODO @ngxson : may need to split into smaller batches
             text_batch.n_tokens = chunk.tokens_text.size();
             for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
@@ -268,8 +285,9 @@ int32_t mtmd_helper_eval(mtmd_context_ptr & ctx,
                 return ret;
             }
 
-        } else if (chunk.type == LLAVA2_INPUT_CHUNK_TYPE_IMAGE) {
+        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
             GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
+            GGML_ASSERT(chunk.tokens_image != nullptr);
             int64_t t0 = ggml_time_ms();
             if (ctx->print_timings) {
                 LOG_INF("encoding image...\n");
@@ -284,7 +302,7 @@ int32_t mtmd_helper_eval(mtmd_context_ptr & ctx,
                 LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
             }
 
-            int32_t n_tokens = chunk.tokens_image.n_tokens;
+            int32_t n_tokens = chunk.tokens_image->n_tokens();
             float * embd = mtmd_get_output_embd(ctx);
             decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
             int64_t t1 = ggml_time_ms();
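`mtmd_helper_eval` walks the chunk list: text chunks are decoded through `text_batch`, while image chunks are run through `mtmd_encode` and their embeddings fed to the decoder via `decode_embd_batch`. A minimal driver sketch under the same assumptions as above (`lctx` is a ready `llama_context *`; the batch size is an arbitrary illustrative value):

    llama_pos n_past = 0;             // start of the sequence
    int32_t ret = mtmd_helper_eval(
        ctx,                          // mtmd_context *
        lctx,                         // llama_context *
        chunks,                       // result of mtmd_tokenize()
        n_past,                       // pos0
        0,                            // seq_id
        2048);                        // n_batch, illustrative value
    if (ret != 0) {
        LOG_ERR("%s: failed to eval chunks, ret = %d\n", __func__, ret);
    }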