Skip to content

Commit 4b96c3b

Browse files
committed
try new batch api (not actually batching)
1 parent 8a7d53d commit 4b96c3b

File tree

1 file changed

+67
-4
lines changed

1 file changed

+67
-4
lines changed

gpttype_adapter.cpp

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,66 @@ static void load_grammar(const std::string & gammarstr)
15141514
}
15151515
}
15161516

1517+
struct kcpp_embd_batch { //duplcated from llava_embd_batch
1518+
std::vector<int32_t> pos;
1519+
std::vector<int32_t> n_seq_id;
1520+
std::vector<int32_t> seq_id_0;
1521+
std::vector<int32_t *> seq_ids;
1522+
std::vector<int8_t> logits;
1523+
llama_batch batch;
1524+
kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) {
1525+
int32_t seq_id = 0;
1526+
pos.resize(n_tokens);
1527+
n_seq_id.resize(n_tokens);
1528+
seq_ids.resize(n_tokens + 1);
1529+
logits.resize(n_tokens);
1530+
seq_id_0.resize(1);
1531+
seq_id_0[0] = seq_id;
1532+
seq_ids [n_tokens] = nullptr;
1533+
batch = {
1534+
/*n_tokens =*/ n_tokens,
1535+
/*tokens =*/ nullptr,
1536+
/*embd =*/ embd,
1537+
/*pos =*/ pos.data(),
1538+
/*n_seq_id =*/ n_seq_id.data(),
1539+
/*seq_id =*/ seq_ids.data(),
1540+
/*logits =*/ logits.data(),
1541+
};
1542+
for (int i = 0; i < n_tokens; i++) {
1543+
batch.pos [i] = npast + i;
1544+
batch.n_seq_id[i] = 1;
1545+
batch.seq_id [i] = seq_id_0.data();
1546+
batch.logits [i] = false;
1547+
}
1548+
}
1549+
kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast) {
1550+
int32_t seq_id = 0;
1551+
int32_t n_tokens = tokens.size();
1552+
pos.resize(n_tokens);
1553+
n_seq_id.resize(n_tokens);
1554+
seq_ids.resize(n_tokens + 1);
1555+
logits.resize(n_tokens);
1556+
seq_id_0.resize(1);
1557+
seq_id_0[0] = seq_id;
1558+
seq_ids [n_tokens] = nullptr;
1559+
batch = {
1560+
/*n_tokens =*/ n_tokens,
1561+
/*tokens =*/ tokens.data(),
1562+
/*embd =*/ nullptr,
1563+
/*pos =*/ pos.data(),
1564+
/*n_seq_id =*/ n_seq_id.data(),
1565+
/*seq_id =*/ seq_ids.data(),
1566+
/*logits =*/ logits.data(),
1567+
};
1568+
for (int i = 0; i < n_tokens; i++) {
1569+
batch.pos [i] = npast + i;
1570+
batch.n_seq_id[i] = 1;
1571+
batch.seq_id [i] = seq_id_0.data();
1572+
batch.logits [i] = false;
1573+
}
1574+
batch.logits[n_tokens - 1] = true;
1575+
}
1576+
};
15171577
static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
15181578
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
15191579

@@ -1522,8 +1582,9 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
15221582
if (n_eval > n_batch) {
15231583
n_eval = n_batch;
15241584
}
1525-
llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr,};
1526-
if (llama_decode(ctx_llama, batch)) {
1585+
float * embd = img_embd+i*n_embd;
1586+
kcpp_embd_batch llava_batch = kcpp_embd_batch(embd, n_eval, *n_past);
1587+
if (llama_decode(ctx_llama, llava_batch.batch)) {
15271588
fprintf(stderr, "\n%s : failed to eval image\n", __func__);
15281589
return false;
15291590
}
@@ -3108,7 +3169,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
31083169
}
31093170
else if(file_format == FileFormat::GGUF_GENERIC)
31103171
{
3111-
evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize))==0);
3172+
kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past);
3173+
evalres = (llama_decode(llama_ctx_v4, batch.batch)==0);
31123174
}
31133175
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
31143176
{
@@ -3485,7 +3547,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
34853547
if(i>0 && sepsize>0)
34863548
{
34873549
//add a separator between each image
3488-
auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize));
3550+
kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past);
3551+
auto evr = llama_decode(llama_ctx_v4, batch.batch);
34893552
if(evr!=0)
34903553
{
34913554
printf("\nError when appending llava separator: %d\n",evr);

0 commit comments

Comments
 (0)