@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
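
The two getter pairs above carry the whole idea: the physical reordering of the heavy logits/embd buffers is deferred from decode() to the first accessor call. A minimal standalone sketch of this accessor-triggered pattern (the Buffer/pending names are illustrative, not from llama.cpp):

    #include <cstdint>
    #include <utility>
    #include <vector>

    struct Buffer {
        std::vector<float> rows;   // n_rows * n_cols, row-major
        uint32_t           n_cols;

        // swaps recorded while producing the rows, not yet applied
        std::vector<std::pair<uint32_t, uint32_t>> pending;

        void apply_pending() {
            for (const auto & [i0, i1] : pending) {
                for (uint32_t k = 0; k < n_cols; k++) {
                    std::swap(rows[i0*n_cols + k], rows[i1*n_cols + k]);
                }
            }
            pending.clear(); // later accesses do no extra work
        }

        // every accessor settles the pending swaps first, like get_logits() above
        float * row(uint32_t i) {
            apply_pending();
            return rows.data() + i*n_cols;
        }
    };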
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         // make the outputs have the same order they had in the user-provided batch
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
-            const uint32_t n_vocab = model.vocab.n_tokens();
-            const uint64_t n_embd  = model.hparams.n_embd;
-
             GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
             // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     continue;
                 }
                 std::swap(out_ids[i], out_ids[j_min]);
-                if (logits_size > 0) {
-                    for (uint32_t k = 0; k < n_vocab; k++) {
-                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                    }
-                }
-                if (embd_size > 0) {
-                    for (uint32_t k = 0; k < n_embd; k++) {
-                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                    }
-                }
+
+                // remember the swaps and apply them lazily upon logits/embeddings access
+                output_swaps.push_back({ i, j_min });
             }
 
             std::fill(output_ids.begin(), output_ids.end(), -1);
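
The selection-style loop above now only records the { i, j_min } pairs instead of moving whole rows. On the in-code TODO about minimizing swaps: one known alternative is to apply the permutation by walking its cycles, which moves each row at most once at the cost of one temporary row. This is not part of the change, just a hedged sketch of one possible answer, assuming the target order is available as a permutation perm where input row perm[i] should end up at position i (the helper name is hypothetical):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // move each row at most once by following the permutation's cycles
    static void apply_permutation(std::vector<float> & rows, uint32_t n_cols,
                                  std::vector<uint32_t> perm) {
        const uint32_t n = (uint32_t) perm.size();
        std::vector<float> tmp(n_cols); // scratch for one row per cycle

        for (uint32_t i = 0; i < n; ++i) {
            if (perm[i] == i) {
                continue;
            }
            // stash the row at the cycle's start, then shift the rest into place
            std::copy(rows.begin() + i*n_cols, rows.begin() + (i + 1)*n_cols, tmp.begin());
            uint32_t j = i;
            while (perm[j] != i) {
                const uint32_t src = perm[j];
                std::copy(rows.begin() + src*n_cols, rows.begin() + (src + 1)*n_cols,
                          rows.begin() + j*n_cols);
                perm[j] = j;
                j = src;
            }
            std::copy(tmp.begin(), tmp.end(), rows.begin() + j*n_cols);
            perm[j] = j;
        }
    }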
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd  = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
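
Net effect on callers: the O(n_outputs * n_vocab) and O(n_outputs * n_embd) copy work runs at most once per decode, inside the first getter call, and not at all if the outputs of that batch are never read back. A simplified usage sketch through the public C API, which wraps the member functions above (setup and error handling omitted):

    #include "llama.h"

    // ctx and batch prepared elsewhere
    if (llama_decode(ctx, batch) == 0) {           // swaps recorded, no rows copied yet
        float * l0 = llama_get_logits_ith(ctx, 0); // first access applies all swaps once
        float * l1 = llama_get_logits_ith(ctx, 1); // already settled, no extra work
    }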