@@ -353,7 +353,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
353353 return batch;
354354}
355355
356- static struct llama_batch_ext * llama_batch_ext_init_impl (int32_t n_tokens_alloc, int32_t embd , int32_t n_seq_max) {
356+ static struct llama_batch_ext * llama_batch_ext_init_impl (int32_t n_tokens_alloc, int32_t n_embd , int32_t n_seq_max) {
357357 llama_batch_ext * batch = new llama_batch_ext{
358358 /* n_tokens =*/ 0 ,
359359 /* max_tokens =*/ n_tokens_alloc,
@@ -366,8 +366,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
366366 /* logits =*/ nullptr ,
367367 };
368368
369- if (embd ) {
370- batch->embd = (float *) malloc (sizeof (float ) * n_tokens_alloc * embd );
369+ if (n_embd ) {
370+ batch->embd = (float *) malloc (sizeof (float ) * n_tokens_alloc * n_embd );
371371 } else {
372372 batch->token = (llama_token *) malloc (sizeof (llama_token) * n_tokens_alloc);
373373 }
@@ -391,14 +391,15 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_
391391
392392struct llama_batch_ext * llama_batch_ext_init_from_embd (
393393 float * embd,
394+ size_t n_tokens,
394395 size_t n_embd,
395396 int32_t pos0,
396397 int32_t seq_id) {
397- struct llama_batch_ext * batch = llama_batch_ext_init_impl (0 , n_embd, 1 );
398- memcpy (batch->embd , embd, n_embd * sizeof (float ));
399- for (size_t i = 0 ; i < n_embd ; i++) {
400- batch->pos [i] = pos0 + i;
401- batch->n_seq_id [i] = 1 ;
398+ struct llama_batch_ext * batch = llama_batch_ext_init_impl (n_tokens , n_embd, 1 );
399+ memcpy (batch->embd , embd, n_tokens * n_embd * sizeof (float ));
400+ for (size_t i = 0 ; i < n_tokens ; i++) {
401+ batch->pos [i] = pos0 + i;
402+ batch->n_seq_id [i] = 1 ;
402403 batch->seq_id [i][0 ] = seq_id;
403404 }
404405 return batch;
0 commit comments