@@ -591,21 +591,29 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
     if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
         size_t n_tokens;
         const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
-        text_batch.n_tokens = n_tokens;
         LOG_DBG("decoding text chunk, n_tokens = %zu\n", n_tokens);
         size_t i = 0;
         while (i < n_tokens) { // split into batches
+            text_batch.n_tokens = 0; // clear the batch
             for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
+                text_batch.n_tokens++;
                 text_batch.token   [i]    = tokens[i];
                 text_batch.pos     [i]    = n_past++;
                 text_batch.n_seq_id[i]    = 1;
                 text_batch.seq_id  [i][0] = seq_id;
                 text_batch.logits  [i]    = false;
             }
-            bool is_last_batch = (i == n_tokens);
-            if (logits_last && is_last_batch) {
+            bool is_last_token = (i == n_tokens);
+            if (logits_last && is_last_token) {
                 text_batch.logits[text_batch.n_tokens - 1] = true;
             }
+            ret = llama_decode(lctx, text_batch);
+            if (ret != 0) {
+                LOG_ERR("failed to decode text\n");
+                llama_batch_free(text_batch);
+                return ret;
+            }
+            *new_n_past += text_batch.n_tokens;
         }
 
     } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
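Note on this hunk: the deleted line set `text_batch.n_tokens` to the full chunk length once, up front, so the batch counter never reflected the per-batch split. The new code clears the counter at the top of each pass of the `while` loop, grows it one token at a time up to `n_batch`, decodes that sub-batch immediately with `llama_decode`, frees the batch and returns on a decode failure, and advances `*new_n_past` by the tokens actually decoded. One caveat a careful reader may spot: the batch arrays are still indexed by the global token index `i`, which stays in bounds across passes only if the batch was allocated with at least `n_tokens` slots. Below is a minimal, self-contained sketch of the split-into-batches pattern that indexes by the batch-local count instead; `Batch`, `decode`, and `eval_text` are hypothetical stand-ins, not the llama.cpp API.

    // Minimal sketch of the split-into-batches pattern above.
    // Batch, decode, and eval_text are hypothetical stand-ins.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Batch {
        int32_t              n_tokens = 0; // tokens staged in the current pass
        std::vector<int32_t> token;        // token ids, capacity n_batch
    };

    static int32_t decode(const Batch & batch) {
        std::printf("decoding sub-batch of %d tokens\n", (int) batch.n_tokens);
        return 0; // pretend the decode succeeded
    }

    static int32_t eval_text(const std::vector<int32_t> & tokens, int32_t n_batch) {
        Batch batch;
        batch.token.resize(n_batch);
        size_t i = 0;
        while (i < tokens.size()) {  // split into batches
            batch.n_tokens = 0;      // clear the batch, as in the diff
            for (; i < tokens.size() && batch.n_tokens < n_batch; i++) {
                // index by the batch-local count so every pass stays in bounds
                batch.token[batch.n_tokens++] = tokens[i];
            }
            int32_t ret = decode(batch); // decode each sub-batch as it fills
            if (ret != 0) {
                return ret;              // propagate the failure to the caller
            }
        }
        return 0;
    }

    int main() {
        std::vector<int32_t> tokens(10);
        for (int32_t t = 0; t < 10; t++) tokens[t] = t;
        return eval_text(tokens, 4); // decodes passes of 4, 4, and 2 tokens
    }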
@@ -697,10 +705,10 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
     }
 
     for (size_t i = 0; i < n_chunks; i++) {
-        bool is_last_chunk = (i == n_chunks - 1);
+        bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
         auto chunk = mtmd_input_chunks_get(chunks, i);
 
-        int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, is_last_chunk && logits_last, &n_past);
+        int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
         if (res != 0) {
             LOG_ERR("failed to eval chunk %zu\n", i);
             return res;
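Note on this hunk: the caller now folds `logits_last` into a per-chunk flag before the call, so `mtmd_helper_eval_chunk_single` receives `true` only for the final chunk and only when the caller asked for logits. Behavior is unchanged; the rename just makes the intent of the argument explicit at the call site. A hedged sketch of the same pattern, with a hypothetical `eval_chunk` standing in for the real helper:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for mtmd_helper_eval_chunk_single; 0 on success.
    static int32_t eval_chunk(int32_t chunk, bool logits_last) {
        (void) chunk;
        (void) logits_last;
        return 0;
    }

    static int32_t eval_chunks(const std::vector<int32_t> & chunks, bool logits_last) {
        for (size_t i = 0; i < chunks.size(); i++) {
            // request logits only on the last chunk, and only if the caller did
            bool chunk_logits_last = (i == chunks.size() - 1) && logits_last;
            int32_t res = eval_chunk(chunks[i], chunk_logits_last);
            if (res != 0) {
                return res; // stop at the first failed chunk, as the helper does
            }
        }
        return 0;
    }

    int main() {
        std::vector<int32_t> chunks = {1, 2, 3};
        return eval_chunks(chunks, /*logits_last=*/true); // only the last chunk gets logits
    }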