@@ -2815,6 +2815,22 @@ struct llama_kv_cache {
     }
 };
 
+// saves the kv_cache state for future recovery
+// used to preserve the kv_cache state before searching for a slot
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t size = 0;
+        uint32_t used = 0;
+        uint32_t n    = 0;
+    } old_state;
+
+    std::vector<llama_kv_cell>    recurrent_cells; // for recurrent models only
+    std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
+
+    bool restore = false;
+};
+
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<struct ggml_context *> ctxs;
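Note on the new struct: it is a one-shot undo record for `llama_kv_cache_find_slot`. A minimal sketch of the intended save/try/restore lifecycle, using only the two functions added in this patch (`kv_self`, `ubatch`, and `compute_status` stand in for the surrounding `llama_decode_internal` locals):

```cpp
llama_kv_slot_restorer restorer;

// the slot search snapshots the cache state into `restorer` before writing
if (!llama_kv_cache_find_slot(kv_self, ubatch, &restorer)) {
    return 1;
}

// ... build and compute the graph ...

if (compute_status != GGML_STATUS_SUCCESS) {
    // roll the cache metadata back to its pre-search state
    llama_kv_cache_slot_restore(restorer, kv_self);
}
```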
@@ -3652,11 +3668,19 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-           const struct llama_ubatch & batch) {
+           const struct llama_ubatch & batch,
+    struct llama_kv_slot_restorer * slot_restorer = nullptr) {
     const uint32_t n_tokens = batch.n_tokens;
     const uint32_t n_seqs = batch.n_seqs;
     const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
+    if (slot_restorer != nullptr) {
+        slot_restorer->old_state.head = cache.head;
+        slot_restorer->old_state.size = cache.size;
+        slot_restorer->old_state.used = cache.used;
+        slot_restorer->old_state.n    = cache.n;
+    }
+
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
@@ -3665,6 +3689,11 @@ static bool llama_kv_cache_find_slot(
         // can only process batches with an equal number of new tokens in each sequence
         GGML_ASSERT(batch.equal_seqs);
 
+        if (slot_restorer != nullptr) {
+            slot_restorer->recurrent_cells = cache.cells;
+            slot_restorer->restore = true;
+        }
+
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
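The recurrent branch snapshots the entire `cells` vector because the slot search for recurrent models mutates cells in place at arbitrary indices, so no compact undo record exists. A self-contained toy of that full-copy rollback strategy (`Cell` is a stand-in for `llama_kv_cell`, not the real struct):

```cpp
#include <cstdint>
#include <vector>

struct Cell { int32_t pos = -1; int32_t src = -1; };

int main() {
    std::vector<Cell> cells(8);

    std::vector<Cell> snapshot = cells; // like recurrent_cells: O(size) copy

    // stand-in for the in-place cell shuffling done during the slot search
    cells[0].pos = 0;
    cells[5].src = 2;

    cells = snapshot; // rollback is a single vector assignment
    return 0;
}
```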
@@ -3853,6 +3882,11 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
+    if (slot_restorer != nullptr) {
+        slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
+        slot_restorer->restore = true;
+    }
+
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
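For non-recurrent caches the write is contiguous and append-like, so recording the boundaries `{head, head + n_tokens}` is enough; the restore path later hands these boundaries to `llama_kv_cache_seq_rm`. A simplified, self-contained toy of the range-recording strategy, clearing by index for brevity:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

struct Cell { int32_t pos = -1; };

int main() {
    std::vector<Cell> cells(32);
    const uint32_t head = 4, n_tokens = 3;

    // like slot_boundaries = {cache.head, cache.head + n_tokens}
    const std::pair<uint32_t, uint32_t> bounds = {head, head + n_tokens};

    for (uint32_t i = 0; i < n_tokens; ++i) {
        cells[head + i].pos = (int32_t) i; // the slot write
    }

    for (uint32_t i = bounds.first; i < bounds.second; ++i) {
        cells[i] = Cell{}; // rollback touches only the recorded range
    }
    return 0;
}
```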
@@ -4142,6 +4176,23 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
     return cparams.flash_attn ? 256u : 32u;
 }
 
+static void llama_kv_cache_slot_restore(
+        const struct llama_kv_slot_restorer & restorer,
+              struct llama_kv_cache & cache) {
+    if (restorer.restore) {
+        cache.head = restorer.old_state.head;
+        cache.size = restorer.old_state.size;
+        cache.used = restorer.old_state.used;
+        cache.n    = restorer.old_state.n;
+
+        if (cache.recurrent) {
+            cache.cells = restorer.recurrent_cells;
+        } else {
+            llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
+        }
+    }
+}
+
 //
 // model loading and saving
 //
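Design note: the restore is caller-driven, so every failure path must remember to invoke it. A hypothetical RAII wrapper (not part of this patch) would make the rollback automatic on scope exit unless the caller commits a successful decode:

```cpp
// hypothetical guard built on the two functions added by this patch
struct llama_kv_slot_guard {
    llama_kv_cache &       cache;
    llama_kv_slot_restorer restorer;
    bool                   committed = false;

    explicit llama_kv_slot_guard(llama_kv_cache & c) : cache(c) {}

    void commit() { committed = true; }

    ~llama_kv_slot_guard() {
        if (!committed) {
            llama_kv_cache_slot_restore(restorer, cache);
        }
    }
};
```

A caller would pass `&guard.restorer` to `llama_kv_cache_find_slot` and call `guard.commit()` once the compute succeeds.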
@@ -17184,6 +17235,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer;
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17268,7 +17320,7 @@ static int llama_decode_internal(
             kv_self.head = 0;
         }
 
-        if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+        if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
             return 1;
         }
 
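Because the new parameter defaults to `nullptr`, every other call site of `llama_kv_cache_find_slot` compiles unchanged and skips all snapshotting; only `llama_decode_internal` opts in. A caller that does not need rollback still reads:

```cpp
// slot_restorer defaults to nullptr: no snapshot is taken
if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
    return 1;
}
```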
@@ -17318,16 +17370,17 @@ static int llama_decode_internal(
         llama_set_inputs(lctx, ubatch);
 
         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-        switch (compute_status) {
-            case GGML_STATUS_SUCCESS:
-                break;
-            case GGML_STATUS_ABORTED:
-                return 2;
-            case GGML_STATUS_ALLOC_FAILED:
-                return -2;
-            case GGML_STATUS_FAILED:
-            default:
-                return -3;
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
         }
 
         // update the kv ring buffer
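The restructured error path performs the restore exactly once for any non-success status, then keeps the old switch's return-code mapping. A hypothetical helper (not in the patch) capturing the same mapping:

```cpp
#include "ggml.h"

// status -> llama_decode return convention, as used above
static int llama_decode_status_to_ret(enum ggml_status status) {
    switch (status) {
        case GGML_STATUS_SUCCESS:      return  0;
        case GGML_STATUS_ABORTED:      return  2;
        case GGML_STATUS_ALLOC_FAILED: return -2;
        case GGML_STATUS_FAILED:
        default:                       return -3;
    }
}
```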