Commit 07d84fa

fix missing n_past in various places

This is actually a revert of ggml-org@cda0e4b.
1 parent 3294036 commit 07d84fa
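
The reverted change had hardcoded the batch position to 0 at several call sites. This commit restores the n_past bookkeeping: every batch handed to llama_decode_ext must start at the absolute KV-cache position of its first token, otherwise consecutive batches all write from position 0 and clobber one another. A minimal sketch of the intended pattern, with parameter meanings inferred from the call sites in this diff (the names pos0, seq_id and output_last are assumptions, not taken from the actual header):

    // Sketch only: llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last),
    // where pos0 is the absolute KV-cache position of tokens[0] -- the batch's "n_past".
    llama_batch_ext_ptr prompt(llama_batch_ext_init_from_text(inp.data(), n_input, /*pos0=*/0, /*seq_id=*/0, true));
    llama_decode_ext(ctx, prompt.get());

    // each generated token continues where the prompt ended: n_input, n_input + 1, ...
    for (int i = 0; i < n_gen; i++) {
        llama_batch_ext_ptr next(llama_batch_ext_init_from_text(&token, 1, /*pos0=*/n_input + i, 0, true));
        llama_decode_ext(ctx, next.get());
        token = sample_next(ctx); // hypothetical sampling helper
    }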

File tree

6 files changed: +18 additions, -18 deletions

examples/llama-bench/llama-bench.cpp

Lines changed: 8 additions & 8 deletions
@@ -1427,7 +1427,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
+static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model = llama_get_model(ctx);
@@ -1444,15 +1444,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true));
         llama_decode_ext(ctx, batch.get());
         n_processed += n_tokens;
     }
 
     llama_synchronize(ctx);
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
+static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model = llama_get_model(ctx);
@@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, n_past + i, 0, true));
         llama_decode_ext(ctx, batch.get());
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
@@ -1610,13 +1610,13 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
             }
             //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
         }
         if (t.n_gen > 0) {
             if (params.progress) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
             }
-            test_gen(ctx, 1, t.n_threads);
+            test_gen(ctx, 1, 0, t.n_threads);
         }
 
         for (int i = 0; i < params.reps; i++) {
@@ -1629,14 +1629,14 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
            }
            if (t.n_gen > 0) {
                if (params.progress) {
                    fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
                            i + 1, params.reps);
                }
-               test_gen(ctx, t.n_gen, t.n_threads);
+               test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
            }
 
            uint64_t t_ns = get_time_ns() - t_start;
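
test_prompt and test_gen now take the starting position explicitly: warmup and prompt runs start at n_past = 0, and the measured generation run passes t.n_prompt so generated tokens are placed after the prompt instead of on top of it. A small standalone illustration of the chunk positions test_prompt produces (hypothetical helper, not part of llama-bench):

    #include <algorithm>
    #include <cstdio>

    // mirrors test_prompt's loop: a prompt of n_prompt tokens is decoded in
    // n_batch-sized chunks, each chunk starting at n_past + n_processed
    static void print_chunk_positions(int n_prompt, int n_past, int n_batch) {
        int n_processed = 0;
        while (n_processed < n_prompt) {
            const int n_tokens = std::min(n_prompt - n_processed, n_batch);
            std::printf("chunk at positions [%d, %d)\n",
                        n_past + n_processed, n_past + n_processed + n_tokens);
            n_processed += n_tokens;
        }
    }

    int main() {
        print_chunk_positions(/*n_prompt=*/10, /*n_past=*/0, /*n_batch=*/4);
        // prints: [0, 4), [4, 8), [8, 10)
    }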

examples/lookahead/lookahead.cpp

Lines changed: 2 additions & 2 deletions
@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());
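
The prompt is decoded in two batches so that logits are computed only for the last token: batch0 carries tokens 0 .. n_input - 2 starting at position 0, so the single token in batch1 belongs at position n_input - 1. The old code pinned it to position 0, overwriting the first prompt token's cache entry. Annotated (positions assume the signature sketched above):

    // batch0: inp[0] .. inp[n_input - 2]  ->  positions 0 .. n_input - 2
    // batch1: inp[n_input - 1]            ->  position  n_input - 1
    // with pos0 = 0, batch1's token would land on position 0 instead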

examples/lookup/lookup.cpp

Lines changed: 2 additions & 2 deletions
@@ -91,8 +91,8 @@ int main(int argc, char ** argv){
 
     const auto t_enc_start = ggml_time_us();
 
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());
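
Same fix as in lookahead.cpp above: the trailing single-token batch1 is now decoded at position n_input - 1 rather than position 0.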

examples/save-load-state/save-load-state.cpp

Lines changed: 2 additions & 2 deletions
@@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
         result1 += next_token_str;
 
         llama_batch_ext_clear(batch);
-        llama_seq_id seq_id = 1;
+        llama_seq_id seq_id = 0;
         llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
 
         if (llama_decode_ext(ctx2, batch)) {
@@ -215,7 +215,7 @@ int main(int argc, char ** argv) {
         result2 += next_token_str;
 
         llama_batch_ext_clear(batch);
-        llama_seq_id seq_id = 1;
+        llama_seq_id seq_id = 1; // seq 1 instead of 0
         llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
 
         if (llama_decode_ext(ctx3, batch)) {
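
The first hunk fixes the sequence id: ctx2 restores the state saved from the original run, which generated on sequence 0, so the continuation must decode into seq_id 0 as well. The second hunk keeps seq_id = 1 deliberately, now documented by the comment, because ctx3 exercises sequence-aware state restore by decoding into a different sequence. Note that the position argument stays a literal 0 at both call sites; this commit does not touch it.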

examples/simple/simple.cpp

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
             // prepare the next batch with the sampled token
             llama_batch_ext_clear(batch);
             llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
+            llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true);
 
             n_decode += 1;
         }
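
simple.cpp already tracks the running position in n_pos; the fix is to pass it when appending the sampled token instead of the literal 0. A paraphrased skeleton of the generation loop (not verbatim simple.cpp; sampling and end-of-generation checks are elided):

    // assumes the prompt occupies positions [0, n_prompt) and `batch` holds the
    // next token to decode at position n_pos (initially n_pos == n_prompt)
    while (n_pos < n_prompt + n_predict) {
        llama_decode_ext(ctx, batch); // consumes the token at position n_pos
        n_pos += 1;

        llama_token new_token_id = 0; // placeholder: sample from the logits here

        // prepare the next batch with the sampled token, at its true position
        llama_batch_ext_clear(batch);
        llama_seq_id seq_id = 0;
        llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true);

        n_decode += 1;
    }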

examples/speculative/speculative.cpp

Lines changed: 3 additions & 3 deletions
@@ -166,9 +166,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
-    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
+    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
     llama_decode_ext(ctx_tgt, batch0);
     llama_decode_ext(ctx_tgt, batch1);
     llama_decode_ext(ctx_dft, batch2);
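
As in lookahead and lookup, the target model's prompt is split: batch0 covers positions 0 .. n_input - 2 and batch1's single token now gets its correct position n_input - 1. batch2 feeds the entire prompt to the draft model in one batch, so its starting position of 0 was already correct and is left unchanged.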
