
Commit 704a303

llama : fix Mamba session save and restore

1 parent: 0dea426
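
Context for the change: with recurrent (Mamba) models, saving a session and loading it back did not round-trip the state correctly. Below is a hedged sketch of the flow this commit fixes, assuming the llama.h API as it stood around this commit (llama_load_model_from_file, llama_new_context_with_model, llama_state_save_file, llama_state_load_file); names and signatures may differ in other revisions, and the decode step is elided.

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    int main(int argc, char ** argv) {
        if (argc < 3) { std::printf("usage: %s model.gguf session.bin\n", argv[0]); return 1; }

        llama_backend_init();
        llama_model   * model = llama_load_model_from_file(argv[1], llama_model_default_params());
        llama_context * ctx   = llama_new_context_with_model(model, llama_context_default_params());

        // ... decode a prompt here; for a Mamba model this advances the recurrent state ...

        // save: write_output_ids() now reorders the outputs first, and the recurrent
        // state cells are serialized along with the rest of the context state
        std::vector<llama_token> tokens;  // the tokens decoded so far, if tracked
        llama_state_save_file(ctx, argv[2], tokens.data(), tokens.size());

        // restore into a fresh context; after this fix the restored recurrent state
        // is actually kept by the next decode instead of being discarded
        llama_context * ctx2 = llama_new_context_with_model(model, llama_context_default_params());
        std::vector<llama_token> restored(4096);
        size_t n_restored = 0;
        llama_state_load_file(ctx2, argv[2], restored.data(), restored.size(), &n_restored);

        llama_free(ctx);
        llama_free(ctx2);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }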

File tree

1 file changed: +51 -47 lines

src/llama.cpp: 51 additions, 47 deletions
@@ -3508,11 +3508,11 @@ static bool llama_kv_cache_find_slot(
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];

-            if (last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
-                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0]);
+                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
+                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
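
The new `cell.pos >= 0` guard matters because a cell that has just been assigned (or just restored) carries no position yet, so comparing against it would always trigger the warning. A tiny hedged sketch of that logic, assuming -1 is the "no position yet" sentinel that the guard implies (illustrative code, not llama.cpp's):

    #include <cstdio>

    struct cell { int pos = -1; };  // assumed sentinel: -1 = cell holds no token yet

    static void place(cell & c, int last_pos, int n_seq_tokens) {
        // warn only when the cell already had a position and the new one does not follow it
        if (c.pos >= 0 && last_pos != c.pos + n_seq_tokens) {
            std::printf("non-consecutive position %d after %d\n", last_pos, c.pos);
        }
        c.pos = last_pos;
    }

    int main() {
        cell c;
        place(c, 7, 8);    // fresh cell: no spurious warning (would have fired without the guard)
        place(c, 15, 8);   // 7 + 8 == 15: consecutive, silent
        place(c, 40, 8);   // 15 + 8 != 40: warns about the gap
    }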
@@ -15013,12 +15013,6 @@ static int llama_decode_internal(
     const auto n_ubatch = cparams.n_ubatch;

-    // TODO: simplify or deprecate
-    std::vector<llama_pos> pos;
-    std::vector<int32_t> n_seq_id;
-    std::vector<llama_seq_id *> seq_id_arr;
-    std::vector<std::vector<llama_seq_id>> seq_id;
-
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
@@ -15636,6 +15630,44 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 }

+// make the outputs have the same order they had in the user-provided batch
+static void llama_reorder_outputs(struct llama_context * ctx) {
+    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
+    if (!out_ids.empty()) {
+        uint32_t n_vocab = ctx->model.hparams.n_vocab;
+        uint32_t n_embd = ctx->model.hparams.n_embd;
+        int32_t n_outputs = ctx->n_outputs;
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (int32_t i = 0; i < n_outputs - 1; ++i) {
+            int32_t j_min = i;
+            for (int32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) { continue; }
+            std::swap(out_ids[i], out_ids[j_min]);
+            if (ctx->logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
+                }
+            }
+            if (ctx->embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
+                }
+            }
+        }
+        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
+        for (int32_t i = 0; i < n_outputs; ++i) {
+            ctx->output_ids[out_ids[i]] = i;
+        }
+        out_ids.clear();
+    }
+}
+
 //
 // quantization
 //
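
For readers who want the reordering trick in isolation: a self-contained sketch (illustrative names, not llama.cpp's API) that applies the same selection sort to the index array and a flat row-major buffer in lockstep, so each output row ends up at the slot its batch position dictates while keeping the number of row swaps small.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // reorder rows of a flat row-major buffer so that out_ids ends up ascending
    static void reorder_rows(std::vector<size_t> & out_ids, std::vector<float> & rows, size_t row_size) {
        const size_t n = out_ids.size();
        for (size_t i = 0; i + 1 < n; ++i) {
            size_t j_min = i;
            for (size_t j = i + 1; j < n; ++j) {
                if (out_ids[j] < out_ids[j_min]) { j_min = j; }
            }
            if (j_min == i) { continue; }
            std::swap(out_ids[i], out_ids[j_min]);
            // swap whole rows, keeping each row paired with its batch position
            std::swap_ranges(rows.begin() + i*row_size, rows.begin() + (i + 1)*row_size,
                             rows.begin() + j_min*row_size);
        }
    }

    int main() {
        std::vector<size_t> out_ids = {2, 0, 1};                       // batch positions, in decode order
        std::vector<float>  logits  = {2.f, 2.f, 0.f, 0.f, 1.f, 1.f};  // 3 rows of 2 values
        reorder_rows(out_ids, logits, 2);
        for (float v : logits) { std::printf("%g ", v); }              // prints: 0 0 1 1 2 2
        std::printf("\n");
    }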
@@ -17822,6 +17854,8 @@ struct llama_data_write {
     }

     void write_output_ids(struct llama_context * ctx) {
+        llama_reorder_outputs(ctx);
+
         const uint32_t n_outputs = ctx->n_outputs;

         std::vector<int32_t> output_pos;
@@ -18192,6 +18226,14 @@ struct llama_data_read {
             kv_self.used = cell_count;
         }

+        if (kv_self.recurrent) {
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                uint32_t cell_id = kv_self.head + i;
+                // make sure the recurrent states will keep their restored state
+                kv_self.cells[cell_id].src = cell_id;
+            }
+        }
+
         return true;
     }
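As I read the diff, `src` names the cell whose saved rolling state the next decode starts from, so restoring the cell data without also pointing `src` back at the cell would leave the freshly written state unused. A hedged stand-in that models only that idea (not llama.cpp's actual data layout):

    #include <cstdint>
    #include <vector>

    struct rcell {
        int32_t src = -1;          // assumed sentinel: no usable state yet
        std::vector<float> state;  // the recurrent (e.g. Mamba) state payload
    };

    // after reading cell_count states into cells[head .. head + cell_count), mark each
    // cell as its own source so the restored payload is what the next decode consumes
    static void mark_restored(std::vector<rcell> & cells, uint32_t head, uint32_t cell_count) {
        for (uint32_t i = 0; i < cell_count; ++i) {
            const uint32_t cell_id = head + i;
            cells[cell_id].src = (int32_t) cell_id;
        }
    }

    int main() {
        std::vector<rcell> cells(4);
        mark_restored(cells, 1, 2);  // cells 1 and 2 now reference themselves
        return (cells[1].src == 1 && cells[2].src == 2) ? 0 : 1;
    }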
@@ -18843,44 +18885,6 @@ void llama_synchronize(struct llama_context * ctx) {
     ctx->t_compute_start_us = 0;
 }

-// make the outputs have the same order they had in the user-provided batch
-static void llama_reorder_outputs(struct llama_context * ctx) {
-    std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
-    if (!out_ids.empty()) {
-        uint32_t n_vocab = ctx->model.hparams.n_vocab;
-        uint32_t n_embd = ctx->model.hparams.n_embd;
-        int32_t n_outputs = ctx->n_outputs;
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (ctx->logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
-                }
-            }
-            if (ctx->embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            ctx->output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
 float * llama_get_logits(struct llama_context * ctx) {
     llama_synchronize(ctx);