Skip to content

Commit ba79369

Browse files
committed
fix llama_batch_ext_init_from_embd
1 parent 07d84fa commit ba79369

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

examples/llava/gemma3-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
148148
int64_t t1 = ggml_time_ms();
149149
eval_text(ctx, "<start_of_image>");
150150
llama_set_causal_attn(ctx.lctx, false);
151-
llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
151+
llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
152152
if (llama_decode_ext(ctx.lctx, batch_img.get())) {
153153
LOG_ERR("failed to decode image\n");
154154
return 1;

examples/llava/llava.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
448448
n_eval = n_batch;
449449
}
450450
float * embd = image_embed->embed+i*n_embd;
451-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
451+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0));
452452
if (llama_decode_ext(ctx_llama, batch.get())) {
453453
LOG_ERR("%s : failed to eval\n", __func__);
454454
return false;

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,11 +938,14 @@ extern "C" {
938938
bool output_last);
939939

940940
// Same with llama_batch_init, but initializes the batch with the provided raw embeddings
941+
// Size of embd should be n_tokens * n_embd
942+
// n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
941943
// First token will be at position pos0
942944
// The sequence ID will be fixed to seq_id
943945
// The batch has to be freed with llama_batch_ext_free()
944946
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
945947
float * embd,
948+
size_t n_tokens,
946949
size_t n_embd,
947950
int32_t pos0,
948951
int32_t seq_id);

src/llama-batch.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
353353
return batch;
354354
}
355355

356-
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
356+
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
357357
llama_batch_ext * batch = new llama_batch_ext{
358358
/*n_tokens =*/ 0,
359359
/*max_tokens =*/ n_tokens_alloc,
@@ -366,8 +366,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
366366
/*logits =*/ nullptr,
367367
};
368368

369-
if (embd) {
370-
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
369+
if (n_embd) {
370+
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd);
371371
} else {
372372
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
373373
}
@@ -391,14 +391,15 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_
391391

392392
struct llama_batch_ext * llama_batch_ext_init_from_embd(
393393
float * embd,
394+
size_t n_tokens,
394395
size_t n_embd,
395396
int32_t pos0,
396397
int32_t seq_id) {
397-
struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1);
398-
memcpy(batch->embd, embd, n_embd * sizeof(float));
399-
for (size_t i = 0; i < n_embd; i++) {
400-
batch->pos [i] = pos0 + i;
401-
batch->n_seq_id[i] = 1;
398+
struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
399+
memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
400+
for (size_t i = 0; i < n_tokens; i++) {
401+
batch->pos [i] = pos0 + i;
402+
batch->n_seq_id[i] = 1;
402403
batch->seq_id [i][0] = seq_id;
403404
}
404405
return batch;

0 commit comments

Comments
 (0)