@@ -810,9 +810,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
-    // reorder logits for backward compatibility
-    output_reorder();
-
     return logits;
 }
 
@@ -855,9 +852,6 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
-    // reorder embeddings for backward compatibility
-    output_reorder();
-
     return embd;
 }
 
@@ -1039,7 +1033,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1230,13 +1224,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    const bool logits_all = n_outputs_all == n_tokens_all;
-
-    const bool is_recurrent = llama_model_is_recurrent(&model);
-
-    sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !is_recurrent,
-            /* logits_all   */ logits_all);
+    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1393,18 +1381,52 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+        auto & out_ids = sbatch.out_ids;
+
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
         for (int64_t i = 0; i < n_outputs_all; ++i) {
-            int64_t out_id = sbatch.out_ids[i];
+            int64_t out_id = out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
                 sorted_output = false;
             }
         }
 
-        if (sorted_output) {
-            sbatch.out_ids.clear();
+        // make the outputs have the same order they had in the user-provided batch
+        // note: this is mostly relevant for recurrent models atm
+        if (!sorted_output) {
+            const uint32_t n_vocab = model.vocab.n_tokens();
+            const uint32_t n_embd  = model.hparams.n_embd;
+
+            GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+            // TODO: is there something more efficient which also minimizes swaps?
+            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+            for (int32_t i = 0; i < n_outputs - 1; ++i) {
+                int32_t j_min = i;
+                for (int32_t j = i + 1; j < n_outputs; ++j) {
+                    if (out_ids[j] < out_ids[j_min]) {
+                        j_min = j;
+                    }
+                }
+                if (j_min == i) { continue; }
+                std::swap(out_ids[i], out_ids[j_min]);
+                if (logits_size > 0) {
+                    for (uint32_t k = 0; k < n_vocab; k++) {
+                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+                    }
+                }
+                if (embd_size > 0) {
+                    for (uint32_t k = 0; k < n_embd; k++) {
+                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+                    }
+                }
+            }
+            std::fill(output_ids.begin(), output_ids.end(), -1);
+            for (int32_t i = 0; i < n_outputs; ++i) {
+                output_ids[out_ids[i]] = i;
+            }
         }
     }
14101432
@@ -1515,44 +1537,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
-void llama_context::output_reorder() {
-    auto & out_ids = sbatch.out_ids;
-    if (!out_ids.empty()) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint32_t n_embd  = model.hparams.n_embd;
-
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(output_ids.begin(), output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
 //
 // graph
 //
@@ -1993,8 +1977,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
-        output_reorder();
-
         const auto n_outputs = this->n_outputs;
         const auto & output_ids = this->output_ids;
 