Skip to content

Commit 863db31

Browse files
committed
fix merging issue
1 parent a230804 commit 863db31

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

tools/llava/mtmd-cli.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_
211211
return 1;
212212
}
213213

214+
ctx.bitmaps.entries.clear();
215+
214216
llama_pos new_n_past;
215217
if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(),
216218
ctx.lctx, // lctx

tools/llava/mtmd.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -591,21 +591,29 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
591591
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
592592
size_t n_tokens;
593593
const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
594-
text_batch.n_tokens = n_tokens;
595594
LOG_DBG("decoding text chunk, n_tokens = %zu\n", n_tokens);
596595
size_t i = 0;
597596
while (i < n_tokens) { // split into batches
597+
text_batch.n_tokens = 0; // clear the batch
598598
for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
599+
text_batch.n_tokens++;
599600
text_batch.token [i] = tokens[i];
600601
text_batch.pos [i] = n_past++;
601602
text_batch.n_seq_id[i] = 1;
602603
text_batch.seq_id [i][0] = seq_id;
603604
text_batch.logits [i] = false;
604605
}
605-
bool is_last_batch = (i == n_tokens);
606-
if (logits_last && is_last_batch) {
606+
bool is_last_token = (i == n_tokens);
607+
if (logits_last && is_last_token) {
607608
text_batch.logits[text_batch.n_tokens - 1] = true;
608609
}
610+
ret = llama_decode(lctx, text_batch);
611+
if (ret != 0) {
612+
LOG_ERR("failed to decode text\n");
613+
llama_batch_free(text_batch);
614+
return ret;
615+
}
616+
*new_n_past += text_batch.n_tokens;
609617
}
610618

611619
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
@@ -697,10 +705,10 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
697705
}
698706

699707
for (size_t i = 0; i < n_chunks; i++) {
700-
bool is_last_chunk = (i == n_chunks - 1);
708+
bool chunk_logits_last = (i == n_chunks - 1) && logits_last;
701709
auto chunk = mtmd_input_chunks_get(chunks, i);
702710

703-
int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, is_last_chunk && logits_last, &n_past);
711+
int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past);
704712
if (res != 0) {
705713
LOG_ERR("failed to eval chunk %zu\n", i);
706714
return res;

0 commit comments

Comments
 (0)