 #include "llava.h"
 
 #include "llama.h"
-#include "common.h"
 
 #include <algorithm>
 #include <cerrno>
@@ -402,6 +401,38 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
     return true;
 }
 
+struct llava_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::array<llama_seq_id, 1> seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ embd,
+            /*pos            =*/ pos.data(),
+            /*n_seq_id       =*/ n_seq_id.data(),
+            /*seq_id         =*/ seq_ids.data(),
+            /*logits         =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
 bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
     int n_embd = llama_n_embd(llama_get_model(ctx_llama));
 
@@ -411,8 +442,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llama_batch batch = llama_batch_get_one(embd, n_eval, *n_past, 0);
-        if (llama_decode(ctx_llama, batch)) {
+        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+        if (llama_decode(ctx_llama, llava_batch.batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return false;
        }