@@ -363,15 +363,16 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
363363 // clear the KV cache
364364 llama_kv_self_clear (ctx);
365365
366- common_batch batch (n_batch, 1 );
366+ llama_batch_ext_ptr batch (llama_batch_ext_init ( n_batch, 1 ) );
367367
368368 for (int j = 0 ; j < num_batches; ++j) {
369369 const int batch_start = start + j * n_batch;
370370 const int batch_size = std::min (end - batch_start, n_batch);
371371
372- batch.clear ( );
372+ llama_batch_ext_clear ( batch.get () );
373373 for (int i = 0 ; i < batch_size; i++) {
374- batch.add_text (tokens[batch_start + i], j*n_batch + i, 0 , true );
374+ llama_seq_id seq_id = 0 ;
375+ llama_batch_ext_add_text (batch.get (), tokens[batch_start + i], j*n_batch + i, &seq_id, 1 , true );
375376 }
376377
377378 // LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
@@ -501,7 +502,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
501502 GGML_ASSERT (n_batch < n_ctx || n_batch % n_ctx == 0 );
502503 GGML_ASSERT (params.n_ctx == n_seq * n_ctx);
503504
504- common_batch batch (std::min (n_batch, n_ctx*n_seq), 1 );
505+ llama_batch_ext_ptr batch (llama_batch_ext_init ( std::min (n_batch, n_ctx*n_seq), 1 ) );
505506
506507 std::vector<float > logits;
507508 if (num_batches > 1 ) {
@@ -552,7 +553,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
552553
553554 int n_outputs = 0 ;
554555
555- batch.clear ( );
556+ llama_batch_ext_clear ( batch.get () );
556557 for (int seq = 0 ; seq < n_seq_batch; seq++) {
557558 int seq_start = batch_start + seq*n_ctx;
558559
@@ -567,7 +568,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
567568 for (int k = 0 ; k < batch_size; ++k) {
568569 const llama_pos pos = j*n_batch + k;
569570 bool output = pos >= first;
570- batch.add_text ( tokens[seq_start + k], pos, seq, output);
571+ llama_batch_ext_add_text ( batch.get (), tokens[seq_start + k], pos, & seq, 1 , output);
571572
572573 n_outputs += output ? 1 : 0 ;
573574 }
@@ -649,26 +650,15 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
649650 return {tokens, ppl, logit_history, prob_history};
650651}
651652
652- static bool decode_helper (llama_context * ctx, common_batch & batch, std::vector<float > & batch_logits, int n_batch, int n_vocab) {
653- int prev_outputs = 0 ;
654- for (int i = 0 ; i < (int ) batch.get_n_tokens (); i += n_batch) {
655- const int n_tokens = std::min<int >(n_batch, batch.get_n_tokens () - i);
656-
657- common_batch batch_view = batch.get_view (i, n_tokens);
658-
659- const int ret = llama_decode_ext (ctx, batch_view.get ());
660- if (ret != 0 ) {
661- LOG_ERR (" failed to decode the batch, n_batch = %d, ret = %d\n " , n_batch, ret);
662- return false ;
663- }
664-
665- int n_outputs = batch_view.n_outputs ;
666-
667- memcpy (batch_logits.data () + size_t (prev_outputs)*n_vocab, llama_get_logits (ctx), size_t (n_outputs)*n_vocab*sizeof (float ));
668-
669- prev_outputs += n_outputs;
653+ static bool decode_helper (llama_context * ctx, llama_batch_ext_ptr & batch, std::vector<float > & batch_logits, size_t n_outputs, int n_vocab) {
654+ const int ret = llama_decode_ext (ctx, batch.get ());
655+ if (ret != 0 ) {
656+ LOG_ERR (" failed to decode the batch, ret = %d\n " , ret);
657+ return false ;
670658 }
671659
660+ memcpy (batch_logits.data (), llama_get_logits (ctx), n_outputs*n_vocab*sizeof (float ));
661+
672662 return true ;
673663}
674664
@@ -836,14 +826,12 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
836826 double acc = 0 .0f ;
837827
838828 const int n_ctx = llama_n_ctx (ctx);
839- const int n_batch = params.n_batch ;
840-
841829 const int n_vocab = llama_vocab_n_tokens (vocab);
842830
843831 const int max_tasks_per_batch = 32 ;
844832 const int max_seq = std::min (4 *max_tasks_per_batch, (int ) llama_n_seq_max (ctx));
845833
846- common_batch batch (n_ctx, 4 );
834+ llama_batch_ext_ptr batch (llama_batch_ext_init ( n_ctx, 4 ) );
847835
848836 std::vector<float > tok_logits (n_vocab);
849837 // TODO: this could be made smaller; it's currently the worst-case size
@@ -859,7 +847,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
859847 size_t i1 = i0;
860848 size_t i_logits = 0 ; // this tells us how many logits were needed before this point in the batch
861849
862- batch.clear ( );
850+ llama_batch_ext_clear ( batch.get () );
863851
864852 // batch as much tasks as possible into the available context
865853 // each task has 4 unique sequence ids - one for each ending
@@ -875,7 +863,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
875863 }
876864
877865 for (size_t i = 0 ; i < hs_cur.common_prefix ; ++i) {
878- batch.add_text_multi_seq (hs_cur.seq_tokens [0 ][i], i, { s0 + 0 , s0 + 1 , s0 + 2 , s0 + 3 }, false );
866+ std::vector<llama_seq_id> seq_ids = { s0 + 0 , s0 + 1 , s0 + 2 , s0 + 3 };
867+ llama_batch_ext_add_text (batch.get (), hs_cur.seq_tokens [0 ][i], i, seq_ids.data (), seq_ids.size (), false );
879868 }
880869 llama_batch_ext_set_output_last (batch.get ());
881870 n_logits += 1 ;
@@ -885,7 +874,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
885874 // TODO: don't evaluate the last token of each sequence
886875 for (size_t i = hs_cur.common_prefix ; i < seq_tokens_size; ++i) {
887876 const bool needs_logits = i < seq_tokens_size - 1 ;
888- batch.add_text_multi_seq (hs_cur.seq_tokens [s][i], i, { s0 + s }, needs_logits);
877+ llama_seq_id seq_id = s0 + s;
878+ llama_batch_ext_add_text (batch.get (), hs_cur.seq_tokens [s][i], i, &seq_id, 1 , needs_logits);
889879 n_logits += needs_logits;
890880 }
891881 }
@@ -907,7 +897,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
907897 llama_kv_self_clear (ctx);
908898
909899 // decode all tasks [i0, i1)
910- if (!decode_helper (ctx, batch, batch_logits, n_batch , n_vocab)) {
900+ if (!decode_helper (ctx, batch, batch_logits, i_logits , n_vocab)) {
911901 LOG_ERR (" %s: llama_decode() failed\n " , __func__);
912902 return ;
913903 }
@@ -1118,14 +1108,12 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
11181108 LOG_INF (" %s : calculating winogrande score over selected tasks.\n " , __func__);
11191109
11201110 const int n_ctx = llama_n_ctx (ctx);
1121- const int n_batch = params.n_batch ;
1122-
11231111 const int n_vocab = llama_vocab_n_tokens (vocab);
11241112
11251113 const int max_tasks_per_batch = 128 ;
11261114 const int max_seq = std::min (2 *max_tasks_per_batch, (int ) llama_n_seq_max (ctx));
11271115
1128- common_batch batch (n_ctx, 2 );
1116+ llama_batch_ext_ptr batch (llama_batch_ext_init ( n_ctx, 2 ) );
11291117
11301118 std::vector<float > tok_logits (n_vocab);
11311119 // TODO: this could be made smaller; it's currently the worst-case size
@@ -1144,7 +1132,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
11441132 size_t i1 = i0;
11451133 size_t i_logits = 0 ;
11461134
1147- batch.clear ( );
1135+ llama_batch_ext_clear ( batch.get () );
11481136
11491137 while (n_cur + (int ) data[i1].required_tokens <= n_ctx) {
11501138 int n_logits = 0 ;
@@ -1154,15 +1142,17 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
11541142 }
11551143
11561144 for (size_t i = 0 ; i < data[i1].common_prefix ; ++i) {
1157- batch.add_text_multi_seq (data[i1].seq_tokens [0 ][i], i, { s0 + 0 , s0 + 1 }, false );
1145+ std::vector<llama_seq_id> seq_ids{ s0 + 0 , s0 + 1 };
1146+ llama_batch_ext_add_text (batch.get (), data[i1].seq_tokens [0 ][i], i, seq_ids.data (), seq_ids.size (), false );
11581147 }
11591148 llama_batch_ext_set_output_last (batch.get ());
11601149 n_logits += 1 ;
11611150
11621151 for (int s = 0 ; s < 2 ; ++s) {
11631152 // TODO: end before the last token, no need to predict past the end of the sequences
11641153 for (size_t i = data[i1].common_prefix ; i < data[i1].seq_tokens [s].size (); ++i) {
1165- batch.add_text_multi_seq (data[i1].seq_tokens [s][i], i, { s0 + s }, true );
1154+ llama_seq_id seq_id = s0 + s;
1155+ llama_batch_ext_add_text (batch.get (), data[i1].seq_tokens [s][i], i, &seq_id, 1 , true );
11661156 n_logits += 1 ;
11671157 }
11681158 }
@@ -1184,7 +1174,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
11841174 llama_kv_self_clear (ctx);
11851175
11861176 // decode all tasks [i0, i1)
1187- if (!decode_helper (ctx, batch, batch_logits, n_batch , n_vocab)) {
1177+ if (!decode_helper (ctx, batch, batch_logits, i_logits , n_vocab)) {
11881178 LOG_ERR (" %s: llama_decode() failed\n " , __func__);
11891179 return ;
11901180 }
@@ -1472,14 +1462,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
14721462 LOG (" \n task\t acc_norm\n " );
14731463
14741464 const int n_ctx = llama_n_ctx (ctx);
1475- const int n_batch = params.n_batch ;
1476-
14771465 const int n_vocab = llama_vocab_n_tokens (vocab);
14781466
14791467 const int max_tasks_per_batch = 32 ;
14801468 const int max_seq = std::min (4 *max_tasks_per_batch, (int ) llama_n_seq_max (ctx));
14811469
1482- common_batch batch (n_ctx, max_seq);
1470+ llama_batch_ext_ptr batch (llama_batch_ext_init ( n_ctx, max_seq) );
14831471
14841472 std::vector<float > tok_logits (n_vocab);
14851473 std::vector<float > batch_logits (size_t (n_ctx)*n_vocab);
@@ -1499,7 +1487,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
14991487 size_t i1 = i0;
15001488 size_t i_logits = 0 ; // this tells us how many logits were needed before this point in the batch
15011489
1502- batch.clear ( );
1490+ llama_batch_ext_clear ( batch.get () );
15031491
15041492 // batch as much tasks as possible into the available context
15051493 // each task has 4 unique sequence ids - one for each ending
@@ -1518,11 +1506,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
15181506 if (int (batch_indeces.size ()) != num_answers) {
15191507 batch_indeces.resize (num_answers);
15201508 }
1521- for (int s = 0 ; s < num_answers; ++s) batch_indeces[s] = s0 + s;
1509+ for (int s = 0 ; s < num_answers; ++s) {
1510+ batch_indeces[s] = s0 + s;
1511+ }
15221512
15231513 for (size_t i = 0 ; i < cur_task.common_prefix ; ++i) {
1524- // llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
1525- batch.add_text_multi_seq (cur_task.seq_tokens [0 ][i], i, batch_indeces, false );
1514+ llama_batch_ext_add_text (batch.get (), cur_task.seq_tokens [0 ][i], i, batch_indeces.data (), batch_indeces.size (), false );
15261515 }
15271516 llama_batch_ext_set_output_last (batch.get ()); // we need logits for the last token of the common prefix
15281517 n_logits += 1 ;
@@ -1532,7 +1521,8 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
15321521 // TODO: don't evaluate the last token of each sequence
15331522 for (size_t i = cur_task.common_prefix ; i < seq_tokens_size; ++i) {
15341523 const bool needs_logits = i < seq_tokens_size - 1 ;
1535- batch.add_text_multi_seq (cur_task.seq_tokens [s][i], i, { s0 + s }, needs_logits);
1524+ llama_seq_id seq_id = { s0 + s };
1525+ llama_batch_ext_add_text (batch.get (), cur_task.seq_tokens [s][i], i, &seq_id, 1 , needs_logits);
15361526 n_logits += needs_logits;
15371527 }
15381528 }
@@ -1556,7 +1546,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
15561546 llama_kv_self_clear (ctx);
15571547
15581548 // decode all tasks [i0, i1)
1559- if (!decode_helper (ctx, batch, batch_logits, n_batch , n_vocab)) {
1549+ if (!decode_helper (ctx, batch, batch_logits, i_logits , n_vocab)) {
15601550 LOG_ERR (" %s: llama_decode() failed\n " , __func__);
15611551 return ;
15621552 }
@@ -1743,7 +1733,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
17431733 // clear the KV cache
17441734 llama_kv_self_clear (ctx);
17451735
1746- common_batch batch (n_batch, 1 );
1736+ llama_batch_ext_ptr batch (llama_batch_ext_init ( n_batch, 1 ) );
17471737
17481738 for (int j = 0 ; j < num_batches; ++j) {
17491739 const int batch_start = start + j * n_batch;
@@ -1757,9 +1747,10 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
17571747 tokens[batch_start] = llama_vocab_bos (vocab);
17581748 }
17591749
1760- batch.clear ( );
1750+ llama_batch_ext_clear ( batch.get () );
17611751 for (int i = 0 ; i < batch_size; i++) {
1762- batch.add_text_multi_seq (tokens[batch_start + i], j*n_batch + i, {0 }, true );
1752+ llama_seq_id seq_id = 0 ;
1753+ llama_batch_ext_add_text (batch.get (), tokens[batch_start + i], j*n_batch + i, &seq_id, 1 , true );
17631754 }
17641755
17651756 if (llama_decode_ext (ctx, batch.get ())) {
0 commit comments