
Commit 7b28046

Update llama-context-mmojo.cpp
Signed-off-by: Brad Hutchings <[email protected]>
1 parent 75b6fb7 commit 7b28046

1 file changed: +36 -13 lines

src/llama-context-mmojo.cpp

Lines changed: 36 additions & 13 deletions
@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
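
Taken together, the two hunks above mean all four output accessors — get_logits(), get_logits_ith(), get_embeddings(), get_embeddings_ith() — now call output_reorder() before dereferencing the buffers, so any reordering deferred during decode() is applied at read time, at most once. A minimal sketch of that read-time pattern, with hypothetical names (row_store, pending_swaps, rows) standing in for the real llama_context members:

    // Sketch only: read-time application of deferred row swaps.
    // 'pending_swaps' and 'rows' are hypothetical stand-ins for
    // llama_context::output_swaps and the logits/embd buffers.
    #include <cstdio>
    #include <utility>
    #include <vector>

    struct row_store {
        std::vector<std::pair<int, int>> pending_swaps; // recorded at write time
        std::vector<int>                 rows;

        // The accessor replays any pending swaps exactly once, then clears
        // the list, so repeated reads after one write cost almost nothing.
        const std::vector<int> & get() {
            for (const auto & s : pending_swaps) {
                std::swap(rows[s.first], rows[s.second]);
            }
            pending_swaps.clear();
            return rows;
        }
    };

    int main() {
        row_store s;
        s.rows          = {10, 20, 30};
        s.pending_swaps = {{0, 2}, {1, 2}}; // as recorded by a sorting pass
        for (int v : s.get()) std::printf("%d ", v); // prints: 30 10 20
        std::printf("\n");
        return 0;
    }

Because the list is cleared after replay, later reads traverse an empty vector and return immediately.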
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // make the outputs have the same order they had in the user-provided batch
     // note: this is mostly relevant for recurrent models atm
     if (!sorted_output) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint64_t n_embd  = model.hparams.n_embd;
-
         GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
         // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 continue;
             }
             std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
+
+            // remember the swaps and apply them lazily upon logits/embeddings access
+            output_swaps.push_back({ i, j_min });
         }
 
         // mmojo-server START
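
The deleted loops eagerly swapped entire logits/embd rows inside the selection-sort pass over out_ids; the replacement records only the (i, j_min) index pairs. Replaying that exact swap sequence later against any equally indexed array reproduces the same permutation, which is what makes the deferred application in output_reorder() correct. A small self-contained illustration of the property (ids, rows, and swaps are stand-ins, not the patch's identifiers):

    // Sketch: swaps recorded while selection-sorting one array reproduce the
    // same permutation when replayed later against another array.
    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    int main() {
        std::vector<int>  ids  = {2, 0, 1};       // stands in for out_ids
        std::vector<char> rows = {'c', 'a', 'b'}; // stands in for logits/embd rows
        std::vector<std::pair<std::size_t, std::size_t>> swaps;

        // Selection sort on ids, recording every swap (mirrors the decode() loop).
        for (std::size_t i = 0; i + 1 < ids.size(); ++i) {
            std::size_t j_min = i;
            for (std::size_t j = i + 1; j < ids.size(); ++j) {
                if (ids[j] < ids[j_min]) {
                    j_min = j;
                }
            }
            if (j_min == i) {
                continue;
            }
            std::swap(ids[i], ids[j_min]);
            swaps.push_back({ i, j_min });
        }

        // Replaying the recorded swaps orders rows the same way: 'a', 'b', 'c'.
        for (const auto & s : swaps) {
            std::swap(rows[s.first], rows[s.second]);
        }
        assert((rows == std::vector<char>{'a', 'b', 'c'}));
        return 0;
    }

A sequence of transpositions is itself a fixed permutation, so the replay order just has to match the recording order.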
@@ -1323,6 +1322,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd  = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
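
With output_reorder() in place, decode() touches only the small out_ids array when restoring batch order; the expensive row swaps — n_vocab floats per logits row, n_embd per embeddings row — run at most once per batch, at first access, and never for callers that do not read the outputs. The output_swaps member and its element type with i0/i1 fields are presumably declared in the accompanying header, which is not part of this diff.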
