@@ -1427,7 +1427,7 @@ struct sql_printer : public printer {
14271427 }
14281428};
14291429
1430- static void test_prompt (llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
1430+ static void test_prompt (llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
14311431 llama_set_n_threads (ctx, n_threads, n_threads);
14321432
14331433 const llama_model * model = llama_get_model (ctx);
@@ -1444,15 +1444,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
14441444 for (int i = 1 ; i < n_tokens; i++) {
14451445 tokens[i] = std::rand () % n_vocab;
14461446 }
1447- llama_batch_ext_ptr batch (llama_batch_ext_init_from_text (tokens.data (), n_tokens, 0 , 0 , true ));
1447+ llama_batch_ext_ptr batch (llama_batch_ext_init_from_text (tokens.data (), n_tokens, n_past + n_processed , 0 , true ));
14481448 llama_decode_ext (ctx, batch.get ());
14491449 n_processed += n_tokens;
14501450 }
14511451
14521452 llama_synchronize (ctx);
14531453}
14541454
1455- static void test_gen (llama_context * ctx, int n_gen, int n_threads) {
1455+ static void test_gen (llama_context * ctx, int n_gen, int n_past, int n_threads) {
14561456 llama_set_n_threads (ctx, n_threads, n_threads);
14571457
14581458 const llama_model * model = llama_get_model (ctx);
@@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
14621462 llama_token token = llama_vocab_get_add_bos (vocab) ? llama_vocab_bos (vocab) : std::rand () % n_vocab;
14631463
14641464 for (int i = 0 ; i < n_gen; i++) {
1465- llama_batch_ext_ptr batch (llama_batch_ext_init_from_text (&token, 1 , 0 , 0 , true ));
1465+ llama_batch_ext_ptr batch (llama_batch_ext_init_from_text (&token, 1 , n_past + i , 0 , true ));
14661466 llama_decode_ext (ctx, batch.get ());
14671467 llama_synchronize (ctx);
14681468 token = std::rand () % n_vocab;
@@ -1610,13 +1610,13 @@ int main(int argc, char ** argv) {
16101610 fprintf (stderr, " llama-bench: benchmark %d/%zu: warmup prompt run\n " , params_idx, params_count);
16111611 }
16121612 // test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
1613- test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1613+ test_prompt (ctx, t.n_prompt , 0 , t.n_batch , t.n_threads );
16141614 }
16151615 if (t.n_gen > 0 ) {
16161616 if (params.progress ) {
16171617 fprintf (stderr, " llama-bench: benchmark %d/%zu: warmup generation run\n " , params_idx, params_count);
16181618 }
1619- test_gen (ctx, 1 , t.n_threads );
1619+ test_gen (ctx, 1 , 0 , t.n_threads );
16201620 }
16211621
16221622 for (int i = 0 ; i < params.reps ; i++) {
@@ -1629,14 +1629,14 @@ int main(int argc, char ** argv) {
16291629 fprintf (stderr, " llama-bench: benchmark %d/%zu: prompt run %d/%d\n " , params_idx, params_count,
16301630 i + 1 , params.reps );
16311631 }
1632- test_prompt (ctx, t.n_prompt , t.n_batch , t.n_threads );
1632+ test_prompt (ctx, t.n_prompt , 0 , t.n_batch , t.n_threads );
16331633 }
16341634 if (t.n_gen > 0 ) {
16351635 if (params.progress ) {
16361636 fprintf (stderr, " llama-bench: benchmark %d/%zu: generation run %d/%d\n " , params_idx, params_count,
16371637 i + 1 , params.reps );
16381638 }
1639- test_gen (ctx, t.n_gen , t.n_threads );
1639+ test_gen (ctx, t.n_gen , t.n_prompt , t. n_threads );
16401640 }
16411641
16421642 uint64_t t_ns = get_time_ns () - t_start;
0 commit comments