Commit cfd5a11

llama : rename llama_reorder_outputs to llama_output_reorder
Also move it closer to llama_output_reserve.

* llama : fix pooled embeddings when using batches with equal_seqs
1 parent 5679a3b commit cfd5a11
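
The embedding fix matters when a caller decodes a batch that contains several sequences with a pooling type enabled and then reads the pooled embedding back per sequence. Below is a minimal sketch of that calling pattern against the public llama.h API; the helper name and the setup assumption (a context created with embeddings = true and a non-NONE pooling_type) are illustrative and not part of this commit.

// Sketch only: decode two sequences in one batch, then read their pooled embeddings.
// Assumes ctx was created with llama_context_params.embeddings = true and a
// pooling type such as LLAMA_POOLING_TYPE_MEAN (illustrative setup, not from this commit).
#include "llama.h"

#include <cstdio>
#include <vector>

static void pooled_embeddings_example(llama_context * ctx,
                                      const std::vector<llama_token> & seq0,
                                      const std::vector<llama_token> & seq1) {
    const int32_t n_tokens = (int32_t) (seq0.size() + seq1.size());
    llama_batch batch = llama_batch_init(n_tokens, /*embd*/ 0, /*n_seq_max*/ 1);

    // append each prompt under its own sequence id
    auto append = [&](const std::vector<llama_token> & toks, llama_seq_id seq) {
        for (size_t i = 0; i < toks.size(); ++i) {
            const int32_t j = batch.n_tokens++;
            batch.token   [j]    = toks[i];
            batch.pos     [j]    = (llama_pos) i;
            batch.n_seq_id[j]    = 1;
            batch.seq_id  [j][0] = seq;
            batch.logits  [j]    = false; // pooled embeddings do not need per-token logits
        }
    };
    append(seq0, 0);
    append(seq1, 1);

    if (llama_decode(ctx, batch) == 0) {
        const int n_embd = llama_n_embd(llama_get_model(ctx));
        // one pooled embedding per sequence, keyed by the caller's seq ids
        const float * e0 = llama_get_embeddings_seq(ctx, 0);
        const float * e1 = llama_get_embeddings_seq(ctx, 1);
        if (e0 && e1) {
            printf("seq 0 dim0 = %f, seq 1 dim0 = %f (n_embd = %d)\n", e0[0], e1[0], n_embd);
        }
    } else {
        fprintf(stderr, "llama_decode failed\n");
    }

    llama_batch_free(batch);
}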

File tree

1 file changed (+43, -44)

src/llama.cpp

Lines changed: 43 additions & 44 deletions
@@ -14756,7 +14756,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             }
         }
 
-        for (int i = 0; i < n_seqs; ++i) {
+        for (int i = 0; i < n_tokens; ++i) {
             if (last_row[i] >= 0) {
                 data[i] = last_row[i];
             }
@@ -14942,6 +14942,43 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     return n_outputs_max;
 }
 
+// make the outputs have the same order they had in the user-provided batch
+static void llama_output_reorder(struct llama_context * ctx) {
+    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
+    if (!out_ids.empty()) {
+        uint32_t n_vocab = ctx->model.hparams.n_vocab;
+        uint32_t n_embd  = ctx->model.hparams.n_embd;
+        int32_t n_outputs = ctx->n_outputs;
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (int32_t i = 0; i < n_outputs - 1; ++i) {
+            int32_t j_min = i;
+            for (int32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) { continue; }
+            std::swap(out_ids[i], out_ids[j_min]);
+            if (ctx->logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
+                }
+            }
+            if (ctx->embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
+                }
+            }
+        }
+        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
+        for (int32_t i = 0; i < n_outputs; ++i) {
+            ctx->output_ids[out_ids[i]] = i;
+        }
+        out_ids.clear();
+    }
+}
 
 static void llama_graph_compute(
         llama_context & lctx,
@@ -15180,8 +15217,8 @@ static int llama_decode_internal(
                         auto & embd_seq_out = lctx.embd_seq;
                         embd_seq_out.clear();
 
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
                             if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                 continue;
                             }
@@ -15631,44 +15668,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 }
 
-// make the outputs have the same order they had in the user-provided batch
-static void llama_reorder_outputs(struct llama_context * ctx) {
-    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
-    if (!out_ids.empty()) {
-        uint32_t n_vocab = ctx->model.hparams.n_vocab;
-        uint32_t n_embd  = ctx->model.hparams.n_embd;
-        int32_t n_outputs = ctx->n_outputs;
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (ctx->logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
-                }
-            }
-            if (ctx->embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            ctx->output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
 //
 // quantization
 //
@@ -17855,7 +17854,7 @@ struct llama_data_write {
     }
 
     void write_output_ids(struct llama_context * ctx) {
-        llama_reorder_outputs(ctx);
+        llama_output_reorder(ctx);
 
         const uint32_t n_outputs = ctx->n_outputs;
 
@@ -18891,7 +18890,7 @@ float * llama_get_logits(struct llama_context * ctx) {
 
     // reorder logits for backward compatibility
     // TODO: maybe deprecate this
-    llama_reorder_outputs(ctx);
+    llama_output_reorder(ctx);
 
     return ctx->logits;
 }
@@ -18939,7 +18938,7 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 
     // reorder embeddings for backward compatibility
     // TODO: maybe deprecate this
-    llama_reorder_outputs(ctx);
+    llama_output_reorder(ctx);
 
     return ctx->embd;
 }
