Commit 47086fa: "apply to the rest"

1 parent 4aabf4e, commit 47086fa

18 files changed: +243 -324 lines changed

This commit continues the migration from the old flat llama_batch struct to the new opaque llama_batch_ext API, applying it to the remaining common helpers and examples (llama_batch_ext_init, llama_batch_ext_add_text, llama_decode_ext, and the llama_batch_ext_ptr RAII wrapper from llama-cpp.h).

common/common.cpp (0 additions, 37 deletions)

@@ -582,43 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens)
     return buf.str();
 }
 
-/*
-std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
-    std::stringstream buf;
-
-    buf << "[ ";
-
-    bool first = true;
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        if (!first) {
-            buf << ", ";
-        } else {
-            first = false;
-        }
-
-        auto detokenized = common_token_to_piece(ctx, batch.token[i]);
-
-        detokenized.erase(
-            std::remove_if(
-                detokenized.begin(),
-                detokenized.end(),
-                [](const unsigned char c) { return !std::isprint(c); }),
-            detokenized.end());
-
-        buf << "\n" << std::to_string(i)
-            << ", token '" << detokenized << "'"
-            << ", pos " << std::to_string(batch.pos[i])
-            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
-            << ", seq_id " << std::to_string(batch.seq_id[i][0])
-            << ", logits " << std::to_string(batch.logits[i]);
-    }
-
-    buf << " ]";
-
-    return buf.str();
-}
-*/
-
 void string_process_escapes(std::string & input) {
     std::size_t input_len = input.length();
     std::size_t output_idx = 0;
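The deleted helper was already commented out, and llama_batch_ext is opaque, so there are no pos/seq_id arrays left to print. If an equivalent debug dump is ever wanted, a minimal sketch could iterate the shadow token list that common_batch keeps; the helper name below is hypothetical and not part of this commit:

    #include <sstream>

    // Hypothetical replacement (not in this commit): dumps the CPU-side
    // shadow copy kept by common_batch. Positions and seq_ids are no
    // longer visible through the opaque llama_batch_ext.
    static std::string common_batch_str(const struct llama_context * ctx, const common_batch & batch) {
        std::stringstream buf;
        buf << "[ ";
        for (size_t i = 0; i < batch.tokens.size(); ++i) {
            const auto piece = common_token_to_piece(ctx, batch.tokens[i].token);
            buf << "\n" << i << ", token '" << piece << "'"
                << ", logits " << (batch.tokens[i].logits ? 1 : 0);
        }
        buf << " ]";
        return buf.str();
    }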

common/common.h (15 additions, 3 deletions)

@@ -516,7 +516,6 @@ void string_process_escapes(std::string & input);
 std::string string_from(bool value);
 std::string string_from(const std::vector<int> & values);
 std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
-std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
 
 //
 // Filesystem utils
@@ -587,10 +586,10 @@ struct common_batch {
     llama_batch_ext_ptr batch;
     struct batch_token {
         llama_token token;
-        llama_seq_id seq_id;
         bool logits;
     };
     std::vector<batch_token> tokens;
+    int n_outputs = 0;
     common_batch() = default;
     common_batch(int32_t n_tokens, int32_t n_seq_max) {
         batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
@@ -602,7 +601,17 @@ struct common_batch {
     }
     void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
         llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
-        tokens.push_back({token, seq_id, logits});
+        tokens.push_back({token, logits});
+        if (logits) {
+            n_outputs++;
+        }
+    }
+    void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
+        tokens.push_back({token, logits});
+        if (logits) {
+            n_outputs++;
+        }
     }
     void set_logits_last() {
         if (!tokens.empty()) {
@@ -622,6 +631,9 @@ struct common_batch {
         view.tokens.reserve(n_tokens);
         for (int32_t i = 0; i < n_tokens; i++) {
             view.tokens.push_back(tokens[offset + i]);
+            if (tokens[offset + i].logits) {
+                view.n_outputs++;
+            }
         }
         return view;
     }
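A note on the common_batch changes: the per-token seq_id is dropped from the shadow batch_token (sequence routing now lives only inside the opaque llama_batch_ext), while an n_outputs counter and a multi-sequence add_text overload are added. A minimal usage sketch, assuming only the header above (the token ids and sizes are placeholders):

    // Sketch: one batch, two sequences; only the second token requests logits.
    common_batch batch(/*n_tokens=*/512, /*n_seq_max=*/2);

    llama_token tok_a = 1, tok_b = 2; // placeholder token ids
    batch.add_text(tok_a, /*pos=*/0, /*seq_id=*/0, /*logits=*/false);
    batch.add_text(tok_b, /*pos=*/1, std::vector<llama_seq_id>{0, 1}, /*logits=*/true);

    // batch.n_outputs == 1: callers can size their logits buffers without
    // inspecting the batch, and the view-slicing helper in the last hunk
    // carries the same bookkeeping into sub-views.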

examples/llava/gemma3-cli.cpp (14 additions, 44 deletions)

@@ -5,6 +5,7 @@
 #include "clip.h"
 #include "stb_image.h"
 #include "llama.h"
+#include "llama-cpp.h"
 #include "ggml.h"
 #include "console.h"
@@ -63,7 +64,7 @@ struct gemma3_context {
     llama_model * model;
     llama_context * lctx;
     const llama_vocab * vocab;
-    llama_batch batch;
+    llama_batch_ext_ptr batch;
 
     int n_threads = 1;
     llama_pos n_past = 0;
@@ -73,7 +74,7 @@
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(params.n_batch, 0, 1);
+        batch.reset(llama_batch_ext_init(params.n_batch, 1));
         init_clip_model(params);
     }
 
@@ -87,50 +88,18 @@
     }
 };
 
-struct decode_embd_batch {
-    std::vector<llama_pos> pos;
-    std::vector<int32_t> n_seq_id;
-    std::vector<llama_seq_id> seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t> logits;
-    llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens =*/ n_tokens,
-            /*tokens   =*/ nullptr,
-            /*embd     =*/ embd,
-            /*pos      =*/ pos.data(),
-            /*n_seq_id =*/ n_seq_id.data(),
-            /*seq_id   =*/ seq_ids.data(),
-            /*logits   =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
     llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
+    llama_batch_ext_clear(ctx.batch.get());
     for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false);
     }
     if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
+        llama_batch_ext_set_output_last(ctx.batch.get());
     }
     // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
+    if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
         LOG_ERR("Failed to decode text\n");
         return 1;
     }
@@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
     int64_t t1 = ggml_time_ms();
     eval_text(ctx, "<start_of_image>");
     llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
+    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
+    if (llama_decode_ext(ctx.lctx, batch_img.get())) {
         LOG_ERR("failed to decode image\n");
         return 1;
     }
@@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
     fflush(stdout);
 
     // eval the token
-    common_batch_clear(ctx.batch);
-    common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
-    if (llama_decode(ctx.lctx, ctx.batch)) {
+    llama_batch_ext_clear(ctx.batch.get());
+    llama_seq_id seq_id = 0;
+    llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
+    if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
         LOG_ERR("failed to decode token\n");
         return 1;
     }
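Two patterns recur here and in the rest of the commit: the batch is now held in llama_batch_ext_ptr (the smart-pointer wrapper from llama-cpp.h), so no explicit free is needed, and every common_batch_add call becomes llama_batch_ext_add_text with a pointer-plus-count pair of sequence ids. A condensed sketch of the single-token decode step under those APIs (the function name and parameters are illustrative, not from this commit):

    // Sketch: decode one sampled token with the ext batch API.
    // The batch is owned by llama_batch_ext_ptr and freed on scope exit,
    // so no llama_batch_ext_free() call is required.
    static int eval_one_token(llama_context * lctx, llama_token token_id, llama_pos pos) {
        llama_batch_ext_ptr batch(llama_batch_ext_init(/*n_tokens=*/1, /*n_seq_max=*/1));
        llama_seq_id seq_id = 0;
        llama_batch_ext_add_text(batch.get(), token_id, pos, &seq_id, 1, /*logits=*/true);
        if (llama_decode_ext(lctx, batch.get())) {
            return 1; // decode failed
        }
        return 0;
    }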

examples/llava/llava.cpp (3 additions, 35 deletions)

@@ -2,6 +2,7 @@
 #include "llava.h"
 
 #include "llama.h"
+#include "llama-cpp.h"
 
 #include <algorithm>
 #include <cerrno>
@@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
     return true;
 }
 
-struct llava_embd_batch {
-    std::vector<llama_pos> pos;
-    std::vector<int32_t> n_seq_id;
-    std::vector<llama_seq_id> seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t> logits;
-    llama_batch batch;
-    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens =*/ n_tokens,
-            /*tokens   =*/ nullptr,
-            /*embd     =*/ embd,
-            /*pos      =*/ pos.data(),
-            /*n_seq_id =*/ n_seq_id.data(),
-            /*seq_id   =*/ seq_ids.data(),
-            /*logits   =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
     int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
 
@@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
-        if (llama_decode(ctx_llama, llava_batch.batch)) {
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
+        if (llama_decode_ext(ctx_llama, batch.get())) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return false;
        }
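Both hand-rolled embedding-batch structs removed in this commit (decode_embd_batch in gemma3-cli and llava_embd_batch here) existed only to wrap a raw float pointer in a llama_batch; llama_batch_ext_init_from_embd collapses that boilerplate into one call. A sketch of the chunked decode loop under the signature used in this diff (the wrapper function is illustrative; the position argument follows the new call site above, which passes 0):

    #include <algorithm> // std::min

    // Sketch: feed image embeddings to the model in n_batch-sized chunks,
    // one llama_batch_ext per chunk, each freed automatically by the
    // smart pointer when it goes out of scope.
    static bool eval_image_embd_chunks(llama_context * ctx_llama, const llava_image_embed * image_embed, int n_batch) {
        const int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
        for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
            const int n_eval = std::min(n_batch, image_embed->n_image_pos - i);
            float * embd = image_embed->embed + i * n_embd;
            llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, /*pos_0=*/0, /*seq_id=*/0));
            if (llama_decode_ext(ctx_llama, batch.get())) {
                return false; // decode failed
            }
        }
        return true;
    }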

examples/llava/qwen2vl-cli.cpp (1 addition, 0 deletions)

@@ -66,6 +66,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
         memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
         memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
 
+        // TODO: move this to llama_batch_ext API
         llama_batch batch = {
             int32_t(n_eval), // n_tokens
             nullptr, // token
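This is the one call site the commit leaves on the raw llama_batch struct, hence the TODO: Qwen2-VL's M-RoPE needs several position values per token, while llama_batch_ext_add_text, as used everywhere else in this commit, takes a single llama_pos. A rough gloss of the buffer layout implied by the memcpys above (the per-section reading is my interpretation, not stated in this diff):

    // Rough gloss: the raw llama_batch points its pos field at four
    // contiguous sections of n_eval positions, one per M-RoPE dimension.
    std::vector<llama_pos> batch_mrope_pos(n_eval * 4);
    // [0,        n_eval)   : section 0 (copied above the hunk, not shown)
    // [n_eval,   n_eval*2) : section 1 (likewise copied just above)
    // [n_eval*2, n_eval*3) : section 2 (first memcpy shown)
    // [n_eval*3, n_eval*4) : section 3 (second memcpy shown)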

examples/lookahead/lookahead.cpp (12 additions, 9 deletions)

@@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
     // seq_id == 0           : the current input token
     // seq_id [1, W]         : tokens from the past N - 1 Jacobi iterations
     // seq_id [W + 1, W + G] : verification n-grams
-    llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
+    llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);
 
     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model, params.sampling);
@@ -204,10 +204,10 @@
         //  V  V  V  V  V  V
         // id
         {
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch);
 
             // current token - first token of the first level
-            common_batch_add(batch, id, n_past, seq_id_all, true);
+            llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true);
 
             // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
             {
@@ -230,9 +230,10 @@
                         const llama_token t = ngrams_observed.tokens[idx + j];
 
                         ngrams_cur[g].tokens [j + 1] = t;
-                        ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
+                        ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);
 
-                        common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
+                        llama_seq_id seq_id = W + 1 + g;
+                        llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);
                     }
                 }
             }
@@ -244,18 +245,20 @@
                    seq_id_look[j] = i + j + 1;
                }
 
-                common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
+                llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
+                        seq_id_look.data(), seq_id_look.size(), false);
            }
 
            // fill the rest of the levels
            for (int j = 1; j < N - 1; j++) {
                for (int i = 0; i < W; i++) {
-                    common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
+                    llama_seq_id seq_id = i + 1;
+                    llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2);
                }
            }
        }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch) != 0) {
            LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
            return 1;
        }
@@ -475,7 +478,7 @@
 
     llama_kv_cache_view_free(&kvc_view);
 
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
 
     llama_backend_free();
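The mechanical rule applied across these hunks: common_batch_add accepted a std::vector<llama_seq_id> (often a braced temporary like { W + 1 + g }), while llama_batch_ext_add_text takes a pointer plus a count, so temporaries become named locals and existing vectors are passed as data()/size(). A side-by-side sketch using the variables from the hunks above:

    // Single sequence: the braced temporary becomes an addressable local.
    llama_seq_id seq_id = W + 1 + g;
    llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, /*logits=*/true);

    // Multiple sequences: the vector is passed as data() + size().
    llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
            seq_id_look.data(), seq_id_look.size(), /*logits=*/false);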
