@@ -565,7 +565,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
             }
 
             for (int k = 0; k < batch_size; ++k) {
-                const int idx = seq*n_ctx + k;
                 const llama_pos pos = j*n_batch + k;
                 bool output = pos >= first;
                 batch.add_text(tokens[seq_start + k], pos, seq, output);
@@ -876,7 +875,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             }
 
             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-                batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+                batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
@@ -886,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1155,15 +1154,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
             }
 
             for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-                batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+                batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
 
             for (int s = 0; s < 2; ++s) {
                 // TODO: end before the last token, no need to predict past the end of the sequences
                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                    batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                    batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
                     n_logits += 1;
                 }
             }
@@ -1523,7 +1522,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
 
             for (size_t i = 0; i < cur_task.common_prefix; ++i) {
                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
-                batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
+                batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
             }
             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
             n_logits += 1;
@@ -1533,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1760,7 +1759,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
 
         batch.clear();
         for (int i = 0; i < batch_size; i++) {
-            batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true);
+            batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
         }
 
         if (llama_decode_ext(ctx, batch.get())) {
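The pattern across these hunks: the single-sequence `add_text(token, pos, seq, output)` stays in use where each token belongs to exactly one sequence (the plain perplexity loop), while every call site that attaches one token to a list of sequence ids (the shared prefix of the HellaSwag/Winogrande/multiple-choice endings, or the fixed `{0}` in `kl_divergence`) switches to `add_text_multi_seq`. The snippet below is only a sketch of what the multi-sequence variant has to do, written against the public `llama_batch` fields; `add_text_multi_seq_sketch` is a hypothetical stand-in, not the PR's wrapper, which builds a `llama_batch_ext` instead.

```cpp
#include <vector>

#include "llama.h"

// Hypothetical stand-in for the wrapper's add_text_multi_seq(): append one token to a
// plain llama_batch (assumed allocated with enough capacity via llama_batch_init) and
// tag it with every sequence id in seq_ids, so a prefix shared by several endings is
// decoded once and its KV cache entries are reused by each sequence.
static void add_text_multi_seq_sketch(llama_batch & batch, llama_token tok, llama_pos pos,
                                      const std::vector<llama_seq_id> & seq_ids, bool output) {
    const int i = batch.n_tokens;

    batch.token   [i] = tok;
    batch.pos     [i] = pos;
    batch.n_seq_id[i] = (int32_t) seq_ids.size();
    for (size_t s = 0; s < seq_ids.size(); ++s) {
        batch.seq_id[i][s] = seq_ids[s];
    }
    batch.logits  [i] = output; // request logits only where the caller scores tokens

    batch.n_tokens++;
}
```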