Skip to content

Commit a363251

Browse files
committed
qwen2vl: use llama_batch_ext_set_pos
1 parent ba79369 commit a363251

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

examples/llava/qwen2vl-cli.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
6666
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
6767
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
6868

69-
// TODO: move this to llama_batch_ext API
70-
llama_batch batch = {
71-
int32_t(n_eval), // n_tokens
72-
nullptr, // token
73-
(image_embed->embed+i*n_embd), // embed
74-
batch_mrope_pos.data(), // pos
75-
nullptr, // n_seq_id
76-
nullptr, // seq_id
77-
nullptr, // logits
78-
};
79-
80-
if (llama_decode(ctx_llama, batch)) {
69+
float * batch_embd = image_embed->embed+i*n_embd;
70+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(batch_embd, n_eval, n_embd, 0, 0));
71+
llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
72+
73+
if (llama_decode_ext(ctx_llama, batch.get())) {
8174
LOG_ERR("%s : failed to eval\n", __func__);
8275
return false;
8376
}

include/llama.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,12 @@ extern "C" {
950950
int32_t pos0,
951951
int32_t seq_id);
952952

953+
// Set arbitrary token positions on the embeddings batch
954+
// Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd()
955+
// n_pos must match the n_tokens of the batch
956+
// Returns -1 if n_pos does not match the n_tokens of the batch
957+
LLAMA_API int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos);
958+
953959
// Get the number of tokens in the batch
954960
LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
955961

src/llama-batch.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,14 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd(
405405
return batch;
406406
}
407407

408+
// Overwrite the positions stored in `batch` with the caller-supplied array.
//   batch : batch whose `pos` buffer is written
//   pos   : source array holding `n_pos` llama_pos values
//   n_pos : number of positions to copy; must equal batch->n_tokens
// Returns 0 after copying, or -1 (nothing copied) when n_pos != batch->n_tokens.
// NOTE(review): assumes batch->pos has room for at least n_tokens entries and that
// neither pointer is null - neither is checked here; confirm against the allocator.
int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) {
409+
// NOTE(review): n_pos is size_t while llama_batch_ext_get_n_tokens() returns
// batch->n_tokens as int32_t, so this comparison presumably mixes signedness - confirm.
if (batch->n_tokens != n_pos) {
410+
return -1;
411+
}
412+
// Sizes match: copy exactly n_pos positions into the batch.
memcpy(batch->pos, pos, n_pos * sizeof(llama_pos));
413+
return 0;
414+
}
415+
408416
// Read-only accessor: returns the batch's n_tokens field as int32_t.
int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
409417
return batch->n_tokens;
410418
}

0 commit comments

Comments
 (0)