@@ -102,15 +102,20 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
102102
103103 std::string prompt_modified (text.text );
104104 std::string marker_modified (ctx->image_marker );
105- projector_type proj_type = clip_get_projector_type (ctx->ctx_clip );
106105 // a bit hacky here, but works for now
107106 // for some models, we need to add prefix and suffix to the image embeddings
108- if (proj_type == PROJECTOR_TYPE_GEMMA3) {
107+ if (clip_is_gemma3 (ctx->ctx_clip )) {
108+ // gemma 3
109109 // <start_of_image> ... (image embeddings) ... <end_of_image>
110110 marker_modified = " <start_of_image>" + ctx->image_marker + " <end_of_image>" ;
111111 string_replace_all (prompt_modified, ctx->image_marker , marker_modified);
112112 }
113113
114+ // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
115+ // for glm-edge, we don't need to add because the tokens are already in the returned embeddings
116+
117+ // TODO @ngxson : glm-edge : remove BOI / EOI tokens embeddings, decode them as normal tokens
118+
114119 std::vector<std::string> parts = string_split_str (prompt_modified, ctx->image_marker );
115120 output.clear ();
116121 output.reserve (parts.size ());
@@ -155,11 +160,20 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
155160 }
156161
157162 mtmd_image_tokens_ptr image_tokens (new mtmd_image_tokens);
158- image_tokens->nx = clip_n_patches (ctx->ctx_clip ); // TODO @ngxson : use clip_n_patches_by_image
163+ image_tokens->nx = clip_n_patches (ctx->ctx_clip ) * batch_f32. entries . size () ; // TODO @ngxson : use clip_n_patches_by_image
159164 image_tokens->ny = 1 ; // TODO
160165 image_tokens->batch_f32 = std::move (batch_f32);
161166 image_tokens->id = bitmaps[i_img].id ; // optional
162167
168+ LOG_DBG (" image_tokens->nx = %d\n " , image_tokens->nx );
169+ LOG_DBG (" image_tokens->ny = %d\n " , image_tokens->ny );
170+ LOG_DBG (" batch_f32 size = %d\n " , (int )image_tokens->batch_f32 .entries .size ());
171+
172+ if (clip_is_glm (ctx->ctx_clip )) {
173+ // glm-edge
174+ image_tokens->nx += 2 ; // add 2 for the begin_of_image and end_of_image token embeddings
175+ }
176+
163177 mtmd_input_chunk chunk{
164178 MTMD_INPUT_CHUNK_TYPE_IMAGE,
165179 {},
@@ -198,11 +212,27 @@ std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
198212int32_t mtmd_encode (mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
199213 int n_mmproj_embd = clip_n_mmproj_embd (ctx->ctx_clip );
200214 ctx->image_embd_v .resize (image_tokens->n_tokens () * n_mmproj_embd);
201- bool ok = clip_image_batch_encode (
202- ctx->ctx_clip ,
203- ctx->n_threads ,
204- &image_tokens->batch_f32 ,
205- ctx->image_embd_v .data ());
215+ bool ok = false ;
216+
217+ if (clip_is_llava (ctx->ctx_clip )) {
218+ // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
219+ const auto & entries = image_tokens->batch_f32 .entries ;
220+ for (size_t i = 0 ; i < entries.size (); i++) {
221+ int n_tokens_per_image = clip_n_patches (ctx->ctx_clip );
222+ ok = clip_image_encode (
223+ ctx->ctx_clip ,
224+ ctx->n_threads ,
225+ entries[i].get (),
226+ ctx->image_embd_v .data () + i*n_mmproj_embd*n_tokens_per_image);
227+ }
228+ } else {
229+ ok = clip_image_batch_encode (
230+ ctx->ctx_clip ,
231+ ctx->n_threads ,
232+ &image_tokens->batch_f32 ,
233+ ctx->image_embd_v .data ());
234+ }
235+
206236 return ok ? 0 : 1 ;
207237}
208238
@@ -268,28 +298,31 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
268298 int32_t ret;
269299 llama_pos n_past = pos0;
270300 llama_batch text_batch = llama_batch_init (n_batch, 0 , 1 );
301+ int n_mmproj_embd = clip_n_mmproj_embd (ctx->ctx_clip );
271302
272303 for (auto & chunk : chunks) {
273304 bool is_last = &chunk == &chunks.back ();
274305 if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
275- // TODO @ngxson : may need to split into smaller batches
276306 text_batch.n_tokens = chunk.tokens_text .size ();
277- for (size_t i = 0 ; i < chunk.tokens_text .size (); i++) {
278- text_batch.token [i] = chunk.tokens_text [i];
279- text_batch.pos [i] = n_past++;
280- text_batch.n_seq_id [i] = 1 ;
281- text_batch.seq_id [i][0 ] = seq_id;
282- text_batch.logits [i] = false ;
283- }
284- if (is_last) {
285- // always get logits for last input chunk
286- text_batch.logits [text_batch.n_tokens - 1 ] = true ;
287- }
288- ret = llama_decode (lctx, text_batch);
289- if (ret != 0 ) {
290- LOG_ERR (" failed to decode text\n " );
291- llama_batch_free (text_batch);
292- return ret;
307+ size_t i = 0 ;
308+ while (i < chunk.tokens_text .size ()) { // split into batches
309+ for (; i < chunk.tokens_text .size () && text_batch.n_tokens < n_batch; i++) {
310+ text_batch.token [i] = chunk.tokens_text [i];
311+ text_batch.pos [i] = n_past++;
312+ text_batch.n_seq_id [i] = 1 ;
313+ text_batch.seq_id [i][0 ] = seq_id;
314+ text_batch.logits [i] = false ;
315+ }
316+ if (is_last) {
317+ // always get logits for last input chunk
318+ text_batch.logits [text_batch.n_tokens - 1 ] = true ;
319+ }
320+ ret = llama_decode (lctx, text_batch);
321+ if (ret != 0 ) {
322+ LOG_ERR (" failed to decode text\n " );
323+ llama_batch_free (text_batch);
324+ return ret;
325+ }
293326 }
294327
295328 } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
@@ -310,20 +343,42 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
310343 }
311344
312345 int32_t n_tokens = mtmd_image_tokens_get_n_tokens (chunk.tokens_image .get ());
346+ int32_t i_batch = 0 ;
347+ int32_t n_img_batches = GGML_PAD (n_tokens, n_batch) / n_batch;
313348 float * embd = mtmd_get_output_embd (ctx);
314- decode_embd_batch batch_img (embd, n_tokens, n_past, 0 );
315- int64_t t1 = ggml_time_ms ();
316- ret = llama_decode (lctx, batch_img.batch );
317- if (ret != 0 ) {
318- LOG_ERR (" failed to decode image\n " );
319- llama_batch_free (text_batch);
320- return ret;
349+
350+ if (mtmd_decode_use_non_causal (ctx)) {
351+ llama_set_causal_attn (lctx, false );
321352 }
322- if (ctx->print_timings ) {
323- LOG_INF (" image decoded in %" PRId64 " ms\n " , ggml_time_ms () - t1);
353+
354+ while (i_batch < n_img_batches) { // split into batches
355+ int32_t pos_offset = i_batch*n_batch;
356+ int32_t n_tokens_batch = std::min (n_batch, n_tokens - pos_offset);
357+ float * embd_batch = embd + pos_offset*n_mmproj_embd;
358+ decode_embd_batch batch_img (embd_batch, n_tokens_batch, n_past, 0 );
359+
360+ printf (" decoding image batch %d/%d, n_tokens_batch = %d\n " , i_batch+1 , n_img_batches, n_tokens_batch);
361+
362+ int64_t t1 = ggml_time_ms ();
363+ ret = llama_decode (lctx, batch_img.batch );
364+ if (ret != 0 ) {
365+ LOG_ERR (" failed to decode image\n " );
366+ llama_set_causal_attn (lctx, true ); // restore causal attn
367+ llama_batch_free (text_batch);
368+ return ret;
369+ }
370+
371+ if (ctx->print_timings ) {
372+ LOG_INF (" image decoded (batch %d/%d) in %" PRId64 " ms\n " , i_batch+1 , n_img_batches, ggml_time_ms () - t1);
373+ }
374+
375+ i_batch++;
376+ n_past += n_tokens_batch;
324377 }
325378
326- n_past += n_tokens;
379+ if (mtmd_decode_use_non_causal (ctx)) {
380+ llama_set_causal_attn (lctx, true );
381+ }
327382
328383 } else {
329384 GGML_ASSERT (false && " chunk type not supported" );
0 commit comments