@@ -1485,16 +1485,50 @@ static void load_grammar(const std::string & gammarstr)
14851485 }
14861486}
14871487
1488- static bool kcpp_eval_image (llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
1488+ struct llava_embd_batch {
1489+ std::vector<llama_pos> pos;
1490+ std::vector<int32_t > n_seq_id;
1491+ std::vector<llama_seq_id> seq_id_0;
1492+ std::vector<llama_seq_id *> seq_ids;
1493+ std::vector<int8_t > logits;
1494+ llama_batch batch;
1495+ llava_embd_batch (float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
1496+ pos .resize (n_tokens);
1497+ n_seq_id.resize (n_tokens);
1498+ seq_ids .resize (n_tokens + 1 );
1499+ logits .resize (n_tokens);
1500+ seq_id_0.resize (1 );
1501+ seq_id_0[0 ] = seq_id;
1502+ seq_ids [n_tokens] = nullptr ;
1503+ batch = {
1504+ /* n_tokens =*/ n_tokens,
1505+ /* tokens =*/ nullptr ,
1506+ /* embd =*/ embd,
1507+ /* pos =*/ pos.data (),
1508+ /* n_seq_id =*/ n_seq_id.data (),
1509+ /* seq_id =*/ seq_ids.data (),
1510+ /* logits =*/ logits.data (),
1511+ };
1512+ for (int i = 0 ; i < n_tokens; i++) {
1513+ batch.pos [i] = pos_0 + i;
1514+ batch.n_seq_id [i] = 1 ;
1515+ batch.seq_id [i] = seq_id_0.data ();
1516+ batch.logits [i] = false ;
1517+ }
1518+ }
1519+ };
1520+
1521+ static bool kcpp_eval_image (llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
14891522 int n_embd = llama_n_embd (llama_get_model (ctx_llama));
14901523
1491- for (int i = 0 ; i < num_img_tokens ; i += n_batch) {
1492- int n_eval = num_img_tokens - i;
1524+ for (int i = 0 ; i < image_embed-> n_image_pos ; i += n_batch) {
1525+ int n_eval = image_embed-> n_image_pos - i;
14931526 if (n_eval > n_batch) {
14941527 n_eval = n_batch;
14951528 }
1496- llama_batch batch = {int32_t (n_eval), nullptr , (img_embd+i*n_embd), nullptr , nullptr , nullptr , nullptr , *n_past, 1 , 0 , };
1497- if (llama_decode (ctx_llama, batch)) {
1529+ float * embd = image_embed->embed +i*n_embd;
1530+ llava_embd_batch llava_batch = llava_embd_batch (embd, n_eval, *n_past, 0 );
1531+ if (llama_decode (ctx_llama, llava_batch.batch )) {
14981532 fprintf (stderr, " \n %s : failed to eval image\n " , __func__);
14991533 return false ;
15001534 }
@@ -1503,6 +1537,43 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
15031537 return true ;
15041538}
15051539
1540+ // static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
1541+ // int n_embd = llama_n_embd(llama_get_model(ctx_llama));
1542+
1543+ // for (int i = 0; i < num_img_tokens; i += n_batch) {
1544+ // int n_eval = num_img_tokens - i;
1545+ // if (n_eval > n_batch) {
1546+ // n_eval = n_batch;
1547+ // }
1548+ // llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
1549+ // if (llama_decode(ctx_llama, batch)) {
1550+ // fprintf(stderr, "\n%s : failed to eval image\n", __func__);
1551+ // return false;
1552+ // }
1553+ // *n_past += n_eval;
1554+ // }
1555+ // return true;
1556+ // }
1557+
1558+ // bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
1559+ // int n_embd = llama_n_embd(llama_get_model(ctx_llama));
1560+
1561+ // for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
1562+ // int n_eval = image_embed->n_image_pos - i;
1563+ // if (n_eval > n_batch) {
1564+ // n_eval = n_batch;
1565+ // }
1566+ // float * embd = image_embed->embed+i*n_embd;
1567+ // llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
1568+ // if (llama_decode(ctx_llama, llava_batch.batch)) {
1569+ // LOG_ERR("%s : failed to eval\n", __func__);
1570+ // return false;
1571+ // }
1572+ // *n_past += n_eval;
1573+ // }
1574+ // return true;
1575+ // }
1576+
15061577// given an old GGUF context and a new context that has some middle portion removed,
15071578// find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
15071578void PurgeMissingTokens (llama_context * ctx, std::vector<int > &current_context_tokens, std::vector<int > &new_context_tokens, const int genamt, const int nctx)
@@ -2119,7 +2190,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
21192190 // determine mem per token
21202191 std::vector<int > tmp = {1 , 2 , 3 , 4 };
21212192 llama_kv_cache_clear (llama_ctx_v4);
2122- auto er = llama_decode (llama_ctx_v4, llama_batch_get_one (tmp.data (), tmp.size (), 0 , 0 ));
2193+ auto er = llama_decode (llama_ctx_v4, llama_batch_get_one (tmp.data (), tmp.size ()));
21232194 if (er!=0 )
21242195 {
21252196 printf (" \n LLAMA EVAL returned nonzero: %d\n " ,er);
@@ -3182,7 +3253,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
31823253 }
31833254 else if (file_format == FileFormat::GGUF_GENERIC)
31843255 {
3185- evalres = (llama_decode (llama_ctx_v4, llama_batch_get_one (embd.data (), embdsize, n_past, 0 ))==0 );
3256+ evalres = (llama_decode (llama_ctx_v4, llama_batch_get_one (embd.data (), embdsize))==0 );
31863257 }
31873258 else if (file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
31883259 {
@@ -3563,7 +3634,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
35633634 if (i>0 && sepsize>0 )
35643635 {
35653636 // add a separator between each image
3566- auto evr = llama_decode (llama_ctx_v4, llama_batch_get_one (llava_sep.data (), sepsize, n_past, 0 ));
3637+ auto evr = llama_decode (llama_ctx_v4, llama_batch_get_one (llava_sep.data (), sepsize));
35673638 if (evr!=0 )
35683639 {
35693640 printf (" \n Error when appending llava separator: %d\n " ,evr);
@@ -3580,18 +3651,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
35803651 {
35813652 printf (" \r Processing LLaVa Embedding %d (%d tokens)" ,(i+1 ), llava_images[i].clp_image_tokens );
35823653 }
3583- bool err = kcpp_eval_image (llama_ctx_v4,llava_images[i].clp_img_embd ,llava_images[i].clp_image_tokens ,kcpp_data->n_batch ,&n_past);
3584- llavatokensevaled += llava_images[i].clp_image_tokens ;
3585- if (!err)
3586- {
3587- llava_composite_image_signature = " " ; // force invalidate
3588- fprintf (stderr, " \n Failed to eval llava image at %d!\n " ,n_past);
3589- output.text = nullptr ;
3590- output.status = 0 ;
3591- output.stopreason = stop_reason::INVALID;
3592- generation_finished = true ;
3593- return output;
3594- }
3654+ // bool err = kcpp_eval_image(llama_ctx_v4,llava_images[i].clp_img_embd,llava_images[i].clp_image_tokens,kcpp_data->n_batch,&n_past);
3655+ // llavatokensevaled += llava_images[i].clp_image_tokens;
3656+ // if(!err)
3657+ // {
3658+ // llava_composite_image_signature = ""; //force invalidate
3659+ // fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past);
3660+ // output.text = nullptr;
3661+ // output.status = 0;
3662+ // output.stopreason = stop_reason::INVALID;
3663+ // generation_finished = true;
3664+ // return output;
3665+ // }
35953666 }
35963667 if (llavatokenscounted!=llavatokensevaled)
35973668 {
0 commit comments