Skip to content

Commit 712149f

Browse files
committed
Merge follow-up
Everything is merged except for the kcpp_eval_image error; I was unable to fix its arguments.
1 parent 63485bd commit 712149f

File tree

1 file changed

+91
-20
lines changed

1 file changed

+91
-20
lines changed

gpttype_adapter.cpp

Lines changed: 91 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,16 +1485,50 @@ static void load_grammar(const std::string & gammarstr)
14851485
}
14861486
}
14871487

1488-
static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
1488+
struct llava_embd_batch {
1489+
std::vector<llama_pos> pos;
1490+
std::vector<int32_t> n_seq_id;
1491+
std::vector<llama_seq_id> seq_id_0;
1492+
std::vector<llama_seq_id *> seq_ids;
1493+
std::vector<int8_t> logits;
1494+
llama_batch batch;
1495+
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
1496+
pos .resize(n_tokens);
1497+
n_seq_id.resize(n_tokens);
1498+
seq_ids .resize(n_tokens + 1);
1499+
logits .resize(n_tokens);
1500+
seq_id_0.resize(1);
1501+
seq_id_0[0] = seq_id;
1502+
seq_ids [n_tokens] = nullptr;
1503+
batch = {
1504+
/*n_tokens =*/ n_tokens,
1505+
/*tokens =*/ nullptr,
1506+
/*embd =*/ embd,
1507+
/*pos =*/ pos.data(),
1508+
/*n_seq_id =*/ n_seq_id.data(),
1509+
/*seq_id =*/ seq_ids.data(),
1510+
/*logits =*/ logits.data(),
1511+
};
1512+
for (int i = 0; i < n_tokens; i++) {
1513+
batch.pos [i] = pos_0 + i;
1514+
batch.n_seq_id[i] = 1;
1515+
batch.seq_id [i] = seq_id_0.data();
1516+
batch.logits [i] = false;
1517+
}
1518+
}
1519+
};
1520+
1521+
static bool kcpp_eval_image(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
14891522
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
14901523

1491-
for (int i = 0; i < num_img_tokens; i += n_batch) {
1492-
int n_eval = num_img_tokens - i;
1524+
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
1525+
int n_eval = image_embed->n_image_pos - i;
14931526
if (n_eval > n_batch) {
14941527
n_eval = n_batch;
14951528
}
1496-
llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
1497-
if (llama_decode(ctx_llama, batch)) {
1529+
float * embd = image_embed->embed+i*n_embd;
1530+
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
1531+
if (llama_decode(ctx_llama, llava_batch.batch)) {
14981532
fprintf(stderr, "\n%s : failed to eval image\n", __func__);
14991533
return false;
15001534
}
@@ -1503,6 +1537,43 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
15031537
return true;
15041538
}
15051539

1540+
// static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
1541+
// int n_embd = llama_n_embd(llama_get_model(ctx_llama));
1542+
1543+
// for (int i = 0; i < num_img_tokens; i += n_batch) {
1544+
// int n_eval = num_img_tokens - i;
1545+
// if (n_eval > n_batch) {
1546+
// n_eval = n_batch;
1547+
// }
1548+
// llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
1549+
// if (llama_decode(ctx_llama, batch)) {
1550+
// fprintf(stderr, "\n%s : failed to eval image\n", __func__);
1551+
// return false;
1552+
// }
1553+
// *n_past += n_eval;
1554+
// }
1555+
// return true;
1556+
// }
1557+
1558+
// bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
1559+
// int n_embd = llama_n_embd(llama_get_model(ctx_llama));
1560+
1561+
// for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
1562+
// int n_eval = image_embed->n_image_pos - i;
1563+
// if (n_eval > n_batch) {
1564+
// n_eval = n_batch;
1565+
// }
1566+
// float * embd = image_embed->embed+i*n_embd;
1567+
// llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
1568+
// if (llama_decode(ctx_llama, llava_batch.batch)) {
1569+
// LOG_ERR("%s : failed to eval\n", __func__);
1570+
// return false;
1571+
// }
1572+
// *n_past += n_eval;
1573+
// }
1574+
// return true;
1575+
// }
1576+
15061577
//given an old GGUF context and a new context that has some middle portion removed,
15071578
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
15081579
void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
@@ -2119,7 +2190,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
21192190
//determine mem per token
21202191
std::vector<int> tmp = {1, 2, 3, 4};
21212192
llama_kv_cache_clear(llama_ctx_v4);
2122-
auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
2193+
auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size()));
21232194
if(er!=0)
21242195
{
21252196
printf("\nLLAMA EVAL returned nonzero: %d\n",er);
@@ -3182,7 +3253,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
31823253
}
31833254
else if(file_format == FileFormat::GGUF_GENERIC)
31843255
{
3185-
evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize, n_past, 0))==0);
3256+
evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize))==0);
31863257
}
31873258
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
31883259
{
@@ -3563,7 +3634,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
35633634
if(i>0 && sepsize>0)
35643635
{
35653636
//add a separator between each image
3566-
auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize, n_past, 0));
3637+
auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize));
35673638
if(evr!=0)
35683639
{
35693640
printf("\nError when appending llava separator: %d\n",evr);
@@ -3580,18 +3651,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
35803651
{
35813652
printf("\rProcessing LLaVa Embedding %d (%d tokens)",(i+1), llava_images[i].clp_image_tokens);
35823653
}
3583-
bool err = kcpp_eval_image(llama_ctx_v4,llava_images[i].clp_img_embd,llava_images[i].clp_image_tokens,kcpp_data->n_batch,&n_past);
3584-
llavatokensevaled += llava_images[i].clp_image_tokens;
3585-
if(!err)
3586-
{
3587-
llava_composite_image_signature = ""; //force invalidate
3588-
fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past);
3589-
output.text = nullptr;
3590-
output.status = 0;
3591-
output.stopreason = stop_reason::INVALID;
3592-
generation_finished = true;
3593-
return output;
3594-
}
3654+
// bool err = kcpp_eval_image(llama_ctx_v4,llava_images[i].clp_img_embd,llava_images[i].clp_image_tokens,kcpp_data->n_batch,&n_past);
3655+
// llavatokensevaled += llava_images[i].clp_image_tokens;
3656+
// if(!err)
3657+
// {
3658+
// llava_composite_image_signature = ""; //force invalidate
3659+
// fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past);
3660+
// output.text = nullptr;
3661+
// output.status = 0;
3662+
// output.stopreason = stop_reason::INVALID;
3663+
// generation_finished = true;
3664+
// return output;
3665+
// }
35953666
}
35963667
if(llavatokenscounted!=llavatokensevaled)
35973668
{

0 commit comments

Comments
 (0)