@@ -1514,6 +1514,66 @@ static void load_grammar(const std::string & gammarstr)
15141514 }
15151515}
15161516
1517+ struct kcpp_embd_batch { // duplcated from llava_embd_batch
1518+ std::vector<int32_t > pos;
1519+ std::vector<int32_t > n_seq_id;
1520+ std::vector<int32_t > seq_id_0;
1521+ std::vector<int32_t *> seq_ids;
1522+ std::vector<int8_t > logits;
1523+ llama_batch batch;
1524+ kcpp_embd_batch (float * embd, int32_t n_tokens, int32_t npast) {
1525+ int32_t seq_id = 0 ;
1526+ pos.resize (n_tokens);
1527+ n_seq_id.resize (n_tokens);
1528+ seq_ids.resize (n_tokens + 1 );
1529+ logits.resize (n_tokens);
1530+ seq_id_0.resize (1 );
1531+ seq_id_0[0 ] = seq_id;
1532+ seq_ids [n_tokens] = nullptr ;
1533+ batch = {
1534+ /* n_tokens =*/ n_tokens,
1535+ /* tokens =*/ nullptr ,
1536+ /* embd =*/ embd,
1537+ /* pos =*/ pos.data (),
1538+ /* n_seq_id =*/ n_seq_id.data (),
1539+ /* seq_id =*/ seq_ids.data (),
1540+ /* logits =*/ logits.data (),
1541+ };
1542+ for (int i = 0 ; i < n_tokens; i++) {
1543+ batch.pos [i] = npast + i;
1544+ batch.n_seq_id [i] = 1 ;
1545+ batch.seq_id [i] = seq_id_0.data ();
1546+ batch.logits [i] = false ;
1547+ }
1548+ }
1549+ kcpp_embd_batch (std::vector<llama_token> & tokens, int32_t npast) {
1550+ int32_t seq_id = 0 ;
1551+ int32_t n_tokens = tokens.size ();
1552+ pos.resize (n_tokens);
1553+ n_seq_id.resize (n_tokens);
1554+ seq_ids.resize (n_tokens + 1 );
1555+ logits.resize (n_tokens);
1556+ seq_id_0.resize (1 );
1557+ seq_id_0[0 ] = seq_id;
1558+ seq_ids [n_tokens] = nullptr ;
1559+ batch = {
1560+ /* n_tokens =*/ n_tokens,
1561+ /* tokens =*/ tokens.data (),
1562+ /* embd =*/ nullptr ,
1563+ /* pos =*/ pos.data (),
1564+ /* n_seq_id =*/ n_seq_id.data (),
1565+ /* seq_id =*/ seq_ids.data (),
1566+ /* logits =*/ logits.data (),
1567+ };
1568+ for (int i = 0 ; i < n_tokens; i++) {
1569+ batch.pos [i] = npast + i;
1570+ batch.n_seq_id [i] = 1 ;
1571+ batch.seq_id [i] = seq_id_0.data ();
1572+ batch.logits [i] = false ;
1573+ }
1574+ batch.logits [n_tokens - 1 ] = true ;
1575+ }
1576+ };
15171577static bool kcpp_eval_image (llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
15181578 int n_embd = llama_n_embd (llama_get_model (ctx_llama));
15191579
@@ -1522,8 +1582,9 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
15221582 if (n_eval > n_batch) {
15231583 n_eval = n_batch;
15241584 }
1525- llama_batch batch = {int32_t (n_eval), nullptr , (img_embd+i*n_embd), nullptr , nullptr , nullptr , nullptr ,};
1526- if (llama_decode (ctx_llama, batch)) {
1585+ float * embd = img_embd+i*n_embd;
1586+ kcpp_embd_batch llava_batch = kcpp_embd_batch (embd, n_eval, *n_past);
1587+ if (llama_decode (ctx_llama, llava_batch.batch )) {
15271588 fprintf (stderr, " \n %s : failed to eval image\n " , __func__);
15281589 return false ;
15291590 }
@@ -3108,7 +3169,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
31083169 }
31093170 else if (file_format == FileFormat::GGUF_GENERIC)
31103171 {
3111- evalres = (llama_decode (llama_ctx_v4, llama_batch_get_one (embd.data (), embdsize))==0 );
3172+ kcpp_embd_batch batch = kcpp_embd_batch (embd, n_past);
3173+ evalres = (llama_decode (llama_ctx_v4, batch.batch )==0 );
31123174 }
31133175 else if (file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
31143176 {
@@ -3485,7 +3547,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
34853547 if (i>0 && sepsize>0 )
34863548 {
34873549 // add a separator between each image
3488- auto evr = llama_decode (llama_ctx_v4, llama_batch_get_one (llava_sep.data (), sepsize));
3550+ kcpp_embd_batch batch = kcpp_embd_batch (embd, n_past);
3551+ auto evr = llama_decode (llama_ctx_v4, batch.batch );
34893552 if (evr!=0 )
34903553 {
34913554 printf (" \n Error when appending llava separator: %d\n " ,evr);
0 commit comments