@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // make the outputs have the same order they had in the user-provided batch
     // note: this is mostly relevant for recurrent models atm
     if (!sorted_output) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint64_t n_embd  = model.hparams.n_embd;
-
         GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
         // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 continue;
             }
             std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
+
+            // remember the swaps and apply them lazily upon logits/embeddings access
+            output_swaps.push_back({ i, j_min });
         }
 
         std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd  = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
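The pattern in this diff is to keep `decode()` cheap by only recording which output rows need to be exchanged, and to perform the actual row copies over the large logits/embeddings buffers the first time the caller reads them. Below is a minimal standalone sketch of that deferred-swap idea; the `output_buffer`, `record_swap`, and `row` names are illustrative only and not part of the llama.cpp API.

```cpp
// Sketch: defer expensive row swaps until the data is actually accessed.
#include <cstdio>
#include <utility>
#include <vector>

struct output_buffer {
    struct swap_info { size_t i0; size_t i1; };

    size_t n_cols = 0;               // row width (e.g. n_vocab for logits)
    std::vector<float> data;         // n_rows * n_cols values, in compute order
    std::vector<swap_info> pending;  // swaps recorded during decode, not yet applied

    // cheap: called from the hot decode path
    void record_swap(size_t i0, size_t i1) { pending.push_back({ i0, i1 }); }

    // expensive part, applied lazily in the order the swaps were recorded
    void reorder() {
        for (const auto & s : pending) {
            for (size_t k = 0; k < n_cols; ++k) {
                std::swap(data[s.i0*n_cols + k], data[s.i1*n_cols + k]);
            }
        }
        pending.clear(); // applying the list twice would undo the reordering
    }

    const float * row(size_t i) {
        reorder();                   // no cost when nothing was recorded
        return data.data() + i*n_cols;
    }
};

int main() {
    output_buffer out;
    out.n_cols = 4;
    out.data   = { 0,0,0,0,  1,1,1,1,  2,2,2,2 };

    // suppose the user-provided batch order requires rows 0 and 2 to trade places
    out.record_swap(0, 2);

    printf("row 0 starts with %.0f\n", out.row(0)[0]); // prints 2 after the lazy reorder
    return 0;
}
```

Clearing the pending list after applying it (as `output_reorder()` does with `output_swaps`) keeps repeated getter calls idempotent, and clearing it again at the start of each decode discards stale swaps from the previous batch.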