
Commit e4868d1

context : perform output reorder lazily upon access after sync (ggml-org#14853)
* context : perform output reorder lazily upon access after sync

ggml-ci

* cont : add TODO
1 parent 820de57 commit e4868d1

File tree: 3 files changed, +47 -13 lines changed

  include/llama.h
  src/llama-context.cpp
  src/llama-context.h

include/llama.h

Lines changed: 2 additions & 0 deletions

@@ -956,6 +956,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // Rows: number of tokens for which llama_batch.logits[i] != 0
     // Cols: n_vocab
+    // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);

     // Logits for the ith token. For positive indices, Equivalent to:
@@ -970,6 +971,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // shape: [n_outputs*n_embd]
     // Otherwise, returns NULL.
+    // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

     // Get the embeddings for the ith token. For positive indices, Equivalent to:
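
The TODO comments above point callers toward the per-row accessor. A minimal sketch of reading logits through llama_get_logits_ith() instead of the bulk pointer, assuming a ctx/model pair that has just gone through a successful llama_decode() and the llama_model_get_vocab()/llama_vocab_n_tokens() helpers from the same header:

// sketch only: ctx, model and the decoded batch are assumed to exist already,
// and the last token of the batch is assumed to have requested logits
const struct llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_vocab_n_tokens(vocab);

// row for the last output; negative indices count from the end, and the row
// already matches the order of the user-provided batch
const float * logits = llama_get_logits_ith(ctx, -1);

// trivial use: pick the argmax token
int32_t best = 0;
for (int32_t t = 1; t < n_vocab; ++t) {
    if (logits[t] > logits[best]) {
        best = t;
    }
}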

src/llama-context.cpp

Lines changed: 36 additions & 13 deletions

@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }

 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }

 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;

+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }

 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }

 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;

+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {

     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();

     bool did_optimize = false;

@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         // make the outputs have the same order they had in the user-provided batch
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
-            const uint32_t n_vocab = model.vocab.n_tokens();
-            const uint64_t n_embd = model.hparams.n_embd;
-
             GGML_ASSERT((size_t) n_outputs == out_ids.size());

             // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     continue;
                 }
                 std::swap(out_ids[i], out_ids[j_min]);
-                if (logits_size > 0) {
-                    for (uint32_t k = 0; k < n_vocab; k++) {
-                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                    }
-                }
-                if (embd_size > 0) {
-                    for (uint32_t k = 0; k < n_embd; k++) {
-                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                    }
-                }
+
+                // remember the swaps and apply them lazily upon logits/embeddings access
+                output_swaps.push_back({ i, j_min });
             }

             std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }

+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
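
In effect, decode() now only records which output rows ended up swapped, and output_reorder() replays those swaps the first time logits or embeddings are read back after the sync. Below is a small standalone sketch of that record-then-replay pattern; the names (swap_info, out_ids, rows) are illustrative, not llama.cpp API, and a single float per row stands in for a whole row of logits:

// standalone sketch of the record-then-replay swap pattern
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct swap_info {
    uint32_t i0;
    uint32_t i1;
};

int main() {
    // pretend these are output row ids in internal (sorted-by-sequence) order
    std::vector<int>   out_ids = { 2, 0, 3, 1 };
    // one float per "row", standing in for a row of logits
    std::vector<float> rows    = { 20.f, 0.f, 30.f, 10.f };

    std::vector<swap_info> swaps;

    // selection-sort style pass: reorder out_ids eagerly, but only *record*
    // the corresponding row swaps (the cheap part that stays in decode())
    for (uint32_t i = 0; i + 1 < out_ids.size(); ++i) {
        uint32_t j_min = i;
        for (uint32_t j = i + 1; j < out_ids.size(); ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) {
            continue;
        }
        std::swap(out_ids[i], out_ids[j_min]);
        swaps.push_back({ i, j_min });
    }

    // lazy part: replay the recorded swaps on first access to the rows
    for (const auto & s : swaps) {
        std::swap(rows[s.i0], rows[s.i1]);
    }
    swaps.clear();

    for (size_t i = 0; i < rows.size(); ++i) {
        printf("row %zu -> %.0f\n", i, rows[i]); // prints 0, 10, 20, 30
    }
    return 0;
}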

src/llama-context.h

Lines changed: 9 additions & 0 deletions

@@ -181,6 +181,8 @@ struct llama_context {
     // Returns max number of outputs for which space was reserved.
     uint32_t output_reserve(int32_t n_outputs);

+    void output_reorder();
+
     //
     // graph
     //
@@ -250,6 +252,13 @@ struct llama_context {

     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers

+    struct swap_info {
+        uint32_t i0;
+        uint32_t i1;
+    };
+
+    std::vector<swap_info> output_swaps;
+
     ggml_backend_sched_ptr sched;

     ggml_backend_t backend_cpu = nullptr;
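
As a rough sense of the trade-off (illustrative numbers, not from the patch): with a hypothetical vocabulary of 128000 tokens and float32 logits, each eager row swap in decode() would move about 2 x 128000 x 4 bytes, roughly 1 MB, whereas a recorded swap_info entry is just two uint32_t values (8 bytes). The heavy copying is deferred to output_reorder(), and since output_swaps is cleared after replay, that work happens at most once per decode, on the first logits/embeddings access.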
