@@ -14756,7 +14756,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             }
         }
 
-        for (int i = 0; i < n_seqs; ++i) {
+        for (int i = 0; i < n_tokens; ++i) {
             if (last_row[i] >= 0) {
                 data[i] = last_row[i];
             }
@@ -14942,6 +14942,43 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     return n_outputs_max;
 }
 
+// make the outputs have the same order they had in the user-provided batch
+static void llama_output_reorder(struct llama_context * ctx) {
+    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
+    if (!out_ids.empty()) {
+        uint32_t n_vocab = ctx->model.hparams.n_vocab;
+        uint32_t n_embd = ctx->model.hparams.n_embd;
+        int32_t n_outputs = ctx->n_outputs;
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (int32_t i = 0; i < n_outputs - 1; ++i) {
+            int32_t j_min = i;
+            for (int32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) { continue; }
+            std::swap(out_ids[i], out_ids[j_min]);
+            if (ctx->logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
+                }
+            }
+            if (ctx->embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
+                }
+            }
+        }
+        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
+        for (int32_t i = 0; i < n_outputs; ++i) {
+            ctx->output_ids[out_ids[i]] = i;
+        }
+        out_ids.clear();
+    }
+}
 
 static void llama_graph_compute(
         llama_context & lctx,
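The new llama_output_reorder() permutes the rows of ctx->logits and ctx->embd back into the order of the user-provided batch: out_ids[i] holds the original batch position of output row i, and the selection sort applies that permutation with as few row swaps as possible. A minimal standalone sketch of the same idea, using a hypothetical reorder_rows() helper and plain std::vector buffers instead of the context fields, could look like this:

#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Sketch: reorder `rows` (out_ids.size() rows of `row_size` floats) so that
// rows end up sorted by their original batch position stored in out_ids,
// using selection sort to keep the number of row swaps minimal.
static void reorder_rows(std::vector<size_t> & out_ids, std::vector<float> & rows, size_t row_size) {
    const size_t n_outputs = out_ids.size();
    for (size_t i = 0; i + 1 < n_outputs; ++i) {
        size_t j_min = i;
        for (size_t j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) { j_min = j; }
        }
        if (j_min == i) { continue; }
        std::swap(out_ids[i], out_ids[j_min]);
        for (size_t k = 0; k < row_size; ++k) {
            std::swap(rows[i*row_size + k], rows[j_min*row_size + k]);
        }
    }
}

int main() {
    // outputs were produced in split order 2, 0, 1 of the original batch
    std::vector<size_t> out_ids = {2, 0, 1};
    std::vector<float>  rows    = {2.f, 2.f, 0.f, 0.f, 1.f, 1.f}; // 3 rows x 2 floats
    reorder_rows(out_ids, rows, 2);
    for (float v : rows) { printf("%g ", v); } // prints: 0 0 1 1 2 2
    printf("\n");
    return 0;
}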
@@ -15180,8 +15217,8 @@ static int llama_decode_internal(
                    auto & embd_seq_out = lctx.embd_seq;
                    embd_seq_out.clear();
 
-                    for (uint32_t i = 0; i < n_tokens; i++) {
-                        const llama_seq_id seq_id = ubatch.seq_id[i][0];
+                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
                         }
@@ -15631,44 +15668,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 }
 
-// make the outputs have the same order they had in the user-provided batch
-static void llama_reorder_outputs(struct llama_context * ctx) {
-    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
-    if (!out_ids.empty()) {
-        uint32_t n_vocab = ctx->model.hparams.n_vocab;
-        uint32_t n_embd = ctx->model.hparams.n_embd;
-        int32_t n_outputs = ctx->n_outputs;
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (ctx->logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
-                }
-            }
-            if (ctx->embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            ctx->output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
 //
 // quantization
 //
@@ -17855,7 +17854,7 @@ struct llama_data_write {
     }
 
     void write_output_ids(struct llama_context * ctx) {
-        llama_reorder_outputs(ctx);
+        llama_output_reorder(ctx);
 
         const uint32_t n_outputs = ctx->n_outputs;
 
@@ -18891,7 +18890,7 @@ float * llama_get_logits(struct llama_context * ctx) {
 
     // reorder logits for backward compatibility
     // TODO: maybe deprecate this
-    llama_reorder_outputs(ctx);
+    llama_output_reorder(ctx);
 
     return ctx->logits;
 }
@@ -18939,7 +18938,7 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 
     // reorder embeddings for backward compatibility
     // TODO: maybe deprecate this
-    llama_reorder_outputs(ctx);
+    llama_output_reorder(ctx);
 
     return ctx->embd;
 }
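After the rename, llama_get_logits() and llama_get_embeddings() still lazily restore the user-provided batch order before returning their buffers, so callers can index rows by the position an output had in the batch they submitted. A hedged usage sketch follows; logit_of is a hypothetical helper, not part of the llama.cpp API, and it assumes llama_decode() already ran on a batch that requested logits for the outputs in question:

#include <cstddef>

#include "llama.h"

// Hypothetical helper: return the logit of token `tok` for the i-th output of
// the last decoded batch. llama_get_logits() calls llama_output_reorder()
// internally, so row `i_output` matches the submitted batch order even if the
// internal sbatch split computed the outputs in a different order.
static float logit_of(struct llama_context * ctx, int32_t i_output, llama_token tok, int32_t n_vocab) {
    const float * logits = llama_get_logits(ctx); // n_outputs x n_vocab, row-major, batch order
    return logits[(size_t) i_output * n_vocab + tok];
}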