
Commit 65f0184

compile ok

1 parent 9fb2d81

9 files changed: 43 additions & 26 deletions

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -607,7 +607,7 @@ struct common_batch {
             n_outputs++;
         }
     }
-    void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
+    void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
         llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
         tokens.push_back({token, seq_ids[0], logits});
         if (logits) {
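
For orientation, a hedged sketch of how the two common_batch helpers differ after this rename. The single-sequence add_text() overload still exists (it appears unchanged as a context line in the perplexity.cpp hunks below); only the std::vector overload gains the explicit name. Construction of common_batch is not shown in this diff, so treat the snippet as illustrative only:

    common_batch batch; // initialization details not part of this diff
    llama_token  tok = 42; // placeholder token id
    llama_pos    pos = 0;

    // one token into one sequence: the unchanged single-seq overload
    batch.add_text(tok, pos, /*seq_id=*/0, /*logits=*/true);

    // the same token fanned out to four sequences (shared-prefix pattern);
    // the renamed overload forwards to llama_batch_ext_add_text()
    batch.add_text_multi_seq(tok, pos, { 0, 1, 2, 3 }, /*logits=*/false);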

examples/llava/llava-cli.cpp

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+        if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
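
This two-line substitution is the recurring migration pattern in the commit: build an owned batch, then decode it. A standalone sketch of the pattern, assembled only from calls visible in this diff; the llama_batch_ext_init_from_text() parameter meanings (tokens, count, starting position, sequence id) are inferred from the call sites, not from a header:

    // before: temporary stack view, nothing to free
    //   if (llama_decode(ctx, llama_batch_get_one(tokens, n_tokens))) { ... }

    // after: heap-allocated batch owned by a smart pointer
    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens, n_tokens, n_past, 0));
    if (llama_decode_ext(ctx, batch.get())) {
        return false; // decode failed; batch is still freed when it leaves scope
    }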

examples/llava/minicpmv-cli.cpp

Lines changed: 2 additions & 1 deletion
@@ -101,7 +101,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+        if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }

examples/llava/qwen2vl-cli.cpp

Lines changed: 14 additions & 6 deletions
@@ -96,16 +96,24 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        auto batch = llama_batch_get_one(&tokens[i], n_eval);
+
         // TODO: add mrope pos ids somewhere else
-        pos.resize(batch.n_tokens * 4);
+        int n_tokens = n_eval;
+        pos.resize(n_tokens * 4);
         std::fill(pos.begin(), pos.end(), 0);
-        for (int j = 0; j < batch.n_tokens * 3; j ++) {
-            pos[j] = *st_pos_id + (j % batch.n_tokens);
+        for (int j = 0; j < n_tokens * 3; j ++) {
+            pos[j] = *st_pos_id + (j % n_tokens);
         }
-        batch.pos = pos.data();
 
-        if (llama_decode(ctx_llama, batch)) {
+        llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
+        for (int j = 0; j < n_eval; j++) {
+            llama_token token = tokens[i + j];
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), token, pos[j], &seq_id, 1, false);
+        }
+        llama_batch_ext_set_output_last(batch.get());
+
+        if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
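
When positions cannot be expressed as a single starting offset, as with Qwen2-VL's M-RoPE position ids here, the batch is assembled token by token instead of via init_from_text. A minimal sketch of that construction path, using only the calls that appear in this hunk (ctx, the token values, and the position vector are placeholders; the M-RoPE layout itself is still marked TODO in the commit):

    std::vector<llama_token> toks = { /* tokens to evaluate */ };
    std::vector<llama_pos>   pos  = { /* one position per token */ };

    llama_batch_ext_ptr batch(llama_batch_ext_init((int32_t) toks.size(), /*n_seq_max=*/1));
    for (size_t j = 0; j < toks.size(); j++) {
        llama_seq_id seq_id = 0;
        // final argument: whether to request logits for this token
        llama_batch_ext_add_text(batch.get(), toks[j], pos[j], &seq_id, 1, false);
    }
    llama_batch_ext_set_output_last(batch.get()); // logits only for the last token
    if (llama_decode_ext(ctx, batch.get())) {
        // handle the decode error
    }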

examples/lookahead/lookahead.cpp

Lines changed: 4 additions & 2 deletions
@@ -92,8 +92,10 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt
-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
-    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_decode_ext(ctx, batch0.get());
+    llama_decode_ext(ctx, batch1.get());
 
     for (int s = 1; s < W + G + 1; ++s) {
         llama_kv_self_seq_cp(ctx, 0, s, -1, -1);

examples/main/main.cpp

Lines changed: 4 additions & 2 deletions
@@ -548,7 +548,8 @@ int main(int argc, char ** argv) {
             int enc_input_size = embd_inp.size();
             llama_token * enc_input_buf = embd_inp.data();
 
-            if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0));
+            if (llama_decode_ext(ctx, batch.get())) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }
@@ -668,7 +669,8 @@ int main(int argc, char ** argv) {
 
             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+            if (llama_decode_ext(ctx, batch.get())) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }

examples/perplexity/perplexity.cpp

Lines changed: 7 additions & 8 deletions
@@ -565,7 +565,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
         }
 
         for (int k = 0; k < batch_size; ++k) {
-            const int idx = seq*n_ctx + k;
             const llama_pos pos = j*n_batch + k;
             bool output = pos >= first;
             batch.add_text(tokens[seq_start + k], pos, seq, output);
@@ -876,7 +875,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         }
 
         for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-            batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+            batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
         }
         llama_batch_ext_set_output_last(batch.get());
         n_logits += 1;
@@ -886,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             // TODO: don't evaluate the last token of each sequence
             for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                 const bool needs_logits = i < seq_tokens_size - 1;
-                batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                 n_logits += needs_logits;
             }
         }
@@ -1155,15 +1154,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         }
 
         for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-            batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+            batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
         }
         llama_batch_ext_set_output_last(batch.get());
         n_logits += 1;
 
         for (int s = 0; s < 2; ++s) {
             // TODO: end before the last token, no need to predict past the end of the sequences
             for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
                 n_logits += 1;
             }
         }
@@ -1523,7 +1522,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
 
         for (size_t i = 0; i < cur_task.common_prefix; ++i) {
             //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
-            batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
+            batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
         }
         llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
         n_logits += 1;
@@ -1533,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             // TODO: don't evaluate the last token of each sequence
             for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                 const bool needs_logits = i < seq_tokens_size - 1;
-                batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                 n_logits += needs_logits;
             }
         }
@@ -1760,7 +1759,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
 
         batch.clear();
         for (int i = 0; i < batch_size; i++) {
-            batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true);
+            batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
         }
 
         if (llama_decode_ext(ctx, batch.get())) {

examples/speculative-simple/speculative-simple.cpp

Lines changed: 2 additions & 1 deletion
@@ -113,7 +113,8 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
 
     // eval the prompt
-    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0));
+    llama_decode_ext(ctx_tgt, batch.get());
 
     // note: keep the last token separate!
     llama_token id_last = inp.back();

examples/speculative/speculative.cpp

Lines changed: 7 additions & 4 deletions
@@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
     }
 
     common_init();
-#ifdef 0
+#if 0
     if (params.speculative.model.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
@@ -166,9 +166,12 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0));
+    llama_decode_ext(ctx_tgt, batch0.get());
+    llama_decode_ext(ctx_tgt, batch1.get());
+    llama_decode_ext(ctx_dft, batch2.get());
 
     const auto t_enc_end = ggml_time_us();
 
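One property of the new API worth spelling out: each llama_batch_ext_ptr owns its batch, so several batches can be alive at once, and each is released automatically when its pointer goes out of scope. A trimmed sketch under those assumptions (identifiers reuse the hunk above; return codes are ignored for brevity):

    {
        llama_batch_ext_ptr tgt_prompt(llama_batch_ext_init_from_text(inp.data(), n_input - 1, 0, 0));
        llama_batch_ext_ptr dft_prompt(llama_batch_ext_init_from_text(inp.data(), n_input, 0, 0));

        llama_decode_ext(ctx_tgt, tgt_prompt.get());
        llama_decode_ext(ctx_dft, dft_prompt.get());
    } // both batches freed here; no manual free call needed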
