@@ -810,9 +810,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
-    // reorder logits for backward compatibility
-    output_reorder();
-
     return logits;
 }
 
@@ -855,9 +852,6 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
-    // reorder embeddings for backward compatibility
-    output_reorder();
-
     return embd;
 }
 
@@ -1039,7 +1033,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1230,13 +1224,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    const bool logits_all = n_outputs_all == n_tokens_all;
-
-    const bool is_recurrent = llama_model_is_recurrent(&model);
-
-    sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !is_recurrent,
-            /* logits_all   */ logits_all);
+    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
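
Note on the hunk above: decode() no longer queries llama_model_is_recurrent() itself; kv_self->sbatch_init() now owns the choice between a simple split and a sequence-aware split. The snippet below is only a minimal standalone sketch of that pattern, with hypothetical stand-in types (memory_i, kv_cache_unified, kv_cache_recurrent here are simplified illustrations, not the real llama.cpp classes):

// sbatch_init_sketch.cpp -- illustrative only; the types are made-up stand-ins
// that mirror the shape of the interface used in the diff, not llama.cpp itself.
#include <cstdio>

struct batch  { int n_tokens; };
struct sbatch { bool simple_split; bool logits_all; };

// the memory implementation owns the batch-splitting policy
struct memory_i {
    virtual ~memory_i() = default;
    virtual sbatch sbatch_init(const batch & b, bool logits_all) const = 0;
};

// unified KV cache: token order is unconstrained, a simple split is enough
struct kv_cache_unified : memory_i {
    sbatch sbatch_init(const batch &, bool logits_all) const override {
        return { /*simple_split=*/true, logits_all };
    }
};

// recurrent cache: state is per-sequence, so tokens must be grouped by sequence
struct kv_cache_recurrent : memory_i {
    sbatch sbatch_init(const batch &, bool logits_all) const override {
        return { /*simple_split=*/false, logits_all };
    }
};

int main() {
    kv_cache_recurrent kv;
    batch b{8};
    // the caller no longer checks the model type
    sbatch sb = kv.sbatch_init(b, /*logits_all=*/false);
    std::printf("simple_split=%d\n", sb.simple_split);
}

The point of the indirection is that decode() stays identical regardless of which cache implementation is active.
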
@@ -1393,18 +1381,52 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+        auto & out_ids = sbatch.out_ids;
+
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
         for (int64_t i = 0; i < n_outputs_all; ++i) {
-            int64_t out_id = sbatch.out_ids[i];
+            int64_t out_id = out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
                 sorted_output = false;
             }
         }
 
-        if (sorted_output) {
-            sbatch.out_ids.clear();
+        // make the outputs have the same order they had in the user-provided batch
+        // note: this is mostly relevant for recurrent models atm
+        if (!sorted_output) {
+            const uint32_t n_vocab = model.vocab.n_tokens();
+            const uint32_t n_embd  = model.hparams.n_embd;
+
+            GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+            // TODO: is there something more efficient which also minimizes swaps?
+            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+            for (int32_t i = 0; i < n_outputs - 1; ++i) {
+                int32_t j_min = i;
+                for (int32_t j = i + 1; j < n_outputs; ++j) {
+                    if (out_ids[j] < out_ids[j_min]) {
+                        j_min = j;
+                    }
+                }
+                if (j_min == i) { continue; }
+                std::swap(out_ids[i], out_ids[j_min]);
+                if (logits_size > 0) {
+                    for (uint32_t k = 0; k < n_vocab; k++) {
+                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+                    }
+                }
+                if (embd_size > 0) {
+                    for (uint32_t k = 0; k < n_embd; k++) {
+                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+                    }
+                }
+            }
+            std::fill(output_ids.begin(), output_ids.end(), -1);
+            for (int32_t i = 0; i < n_outputs; ++i) {
+                output_ids[out_ids[i]] = i;
+            }
         }
     }
 
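The reordering that used to be deferred to output_reorder() now runs at the end of decode(), and only when the ubatch splitting actually permuted the outputs. A self-contained toy version of the same selection-sort row swap (made-up sizes and values, but the same technique as the added block above):

// toy_reorder.cpp -- standalone illustration of the selection-sort reordering;
// n_outputs, n_vocab and the data are invented for the demo.
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const int n_outputs = 4; // output rows, currently in compute order
    const int n_vocab   = 3; // columns per logits row (toy value)

    // out_ids[i] = position in the user batch that produced output row i
    std::vector<int> out_ids = {2, 0, 3, 1};

    // one logits row per output, tagged with its out_id for visibility
    std::vector<float> logits(n_outputs * n_vocab);
    for (int i = 0; i < n_outputs; ++i)
        for (int k = 0; k < n_vocab; ++k)
            logits[i*n_vocab + k] = 10.0f*out_ids[i] + k;

    // selection sort on out_ids, swapping whole logits rows along with the ids
    for (int i = 0; i < n_outputs - 1; ++i) {
        int j_min = i;
        for (int j = i + 1; j < n_outputs; ++j)
            if (out_ids[j] < out_ids[j_min])
                j_min = j;
        if (j_min == i) continue;
        std::swap(out_ids[i], out_ids[j_min]);
        for (int k = 0; k < n_vocab; ++k)
            std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
    }

    // output_ids maps a batch position back to its (now sorted) output row
    std::vector<int> output_ids(n_outputs, -1);
    for (int i = 0; i < n_outputs; ++i)
        output_ids[out_ids[i]] = i;

    for (int i = 0; i < n_outputs; ++i)
        std::printf("row %d: out_id=%d first_logit=%.0f output_ids[%d]=%d\n",
                    i, out_ids[i], logits[i*n_vocab], i, output_ids[i]);
}

Selection sort is used because it performs at most n_outputs - 1 swaps, and each swap moves a full logits/embeddings row.
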
@@ -1515,44 +1537,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
-void llama_context::output_reorder() {
-    auto & out_ids = sbatch.out_ids;
-    if (!out_ids.empty()) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint32_t n_embd  = model.hparams.n_embd;
-
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(output_ids.begin(), output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
 //
 // graph
 //
@@ -1993,8 +1977,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
-        output_reorder();
-
         const auto n_outputs    = this->n_outputs;
        const auto & output_ids = this->output_ids;
 