Skip to content

Commit c3dd790

Browse files
committed
fix llama_batch_ext_init_from_text
1 parent 65f0184 commit c3dd790

File tree

18 files changed

+40
-27
lines changed

18 files changed

+40
-27
lines changed

common/common.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ struct common_init_result common_init_from_params(common_params & params) {
10141014
}
10151015

10161016
if (llama_model_has_encoder(model)) {
1017-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
1017+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true));
10181018
llama_encode_ext(lctx, batch.get());
10191019
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
10201020
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -1024,7 +1024,7 @@ struct common_init_result common_init_from_params(common_params & params) {
10241024
tmp.push_back(decoder_start_token_id);
10251025
}
10261026
if (llama_model_has_decoder(model)) {
1027-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
1027+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true));
10281028
llama_decode_ext(lctx, batch.get());
10291029
}
10301030
llama_kv_self_clear(lctx);

examples/cvector-generator/cvector-generator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
343343

344344
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
345345
llama_kv_self_clear(ctx);
346-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
346+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
347+
llama_batch_ext_set_output_last(batch.get());
347348
if (llama_decode_ext(ctx, batch.get())) {
348349
fprintf(stderr, "%s : failed to eval\n", __func__);
349350
return false;

examples/eval-callback/eval-callback.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {
134134

135135
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
136136

137-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
137+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
138138
if (llama_decode_ext(ctx, batch.get())) {
139139
LOG_ERR("%s : failed to eval\n", __func__);
140140
return false;

examples/infill/infill.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ int main(int argc, char ** argv) {
353353

354354
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
355355

356-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
356+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
357357
if (llama_decode_ext(ctx, batch.get())) {
358358
LOG_ERR("%s : failed to eval\n", __func__);
359359
return 1;

examples/llama-bench/llama-bench.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
14441444
for (int i = 1; i < n_tokens; i++) {
14451445
tokens[i] = std::rand() % n_vocab;
14461446
}
1447-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0));
1447+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true));
1448+
llama_batch_ext_set_output_last(batch.get());
14481449
llama_decode_ext(ctx, batch.get());
14491450
n_processed += n_tokens;
14501451
}
@@ -1462,7 +1463,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
14621463
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
14631464

14641465
for (int i = 0; i < n_gen; i++) {
1465-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0));
1466+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true));
14661467
llama_decode_ext(ctx, batch.get());
14671468
llama_synchronize(ctx);
14681469
token = std::rand() % n_vocab;

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
2020
if (n_eval > n_batch) {
2121
n_eval = n_batch;
2222
}
23-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
23+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
2424
if (llama_decode_ext(ctx_llama, batch.get())) {
2525
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
2626
return false;

examples/llava/minicpmv-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
101101
if (n_eval > n_batch) {
102102
n_eval = n_batch;
103103
}
104-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
104+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
105105
if (llama_decode_ext(ctx_llama, batch.get())) {
106106
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
107107
return false;

examples/lookahead/lookahead.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
9292
const auto t_enc_start = ggml_time_us();
9393

9494
// eval the prompt
95-
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
96-
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
95+
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
96+
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
9797
llama_decode_ext(ctx, batch0.get());
9898
llama_decode_ext(ctx, batch1.get());
9999

examples/lookup/lookup.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ int main(int argc, char ** argv){
9191

9292
const auto t_enc_start = ggml_time_us();
9393

94-
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
95-
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
94+
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
95+
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
9696
llama_decode_ext(ctx, batch0.get());
9797
llama_decode_ext(ctx, batch1.get());
9898

examples/main/main.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ int main(int argc, char ** argv) {
548548
int enc_input_size = embd_inp.size();
549549
llama_token * enc_input_buf = embd_inp.data();
550550

551-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0));
551+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0, true));
552552
if (llama_decode_ext(ctx, batch.get())) {
553553
LOG_ERR("%s : failed to eval\n", __func__);
554554
return 1;
@@ -669,7 +669,8 @@ int main(int argc, char ** argv) {
669669

670670
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
671671

672-
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
672+
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
673+
llama_batch_ext_set_output_last(batch.get());
673674
if (llama_decode_ext(ctx, batch.get())) {
674675
LOG_ERR("%s : failed to eval\n", __func__);
675676
return 1;

0 commit comments

Comments (0)