--- a/include/llama.h
+++ b/include/llama.h
@@ -677,12 +677,14 @@ extern "C" {
 
     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                    llama_seq_id   seq_id);
 
     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
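The strengthened comments promise that a sequence's cached positions form one contiguous range. A minimal caller-side sketch of what that guarantee permits (the helper `seq_cached_len` is hypothetical; `llama_kv_self_seq_pos_min`/`llama_kv_self_seq_pos_max` are the declarations above):

```cpp
#include "llama.h"

// hypothetical helper: number of positions currently cached for seq_id,
// valid because every position in [pos_min, pos_max] is guaranteed present
static int32_t seq_cached_len(llama_context * ctx, llama_seq_id seq_id) {
    const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
    const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

    if (p_min < 0 || p_max < 0) {
        return 0; // both functions return -1 for an empty sequence
    }

    return (int32_t) (p_max - p_min + 1); // no holes inside the range
}
```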
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -6,9 +6,10 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
+#include <limits>
 #include <cstring>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -951,19 +952,51 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
+    int ret = 0;
+
     const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
     if (compute_status != GGML_STATUS_SUCCESS) {
         switch (compute_status) {
             case GGML_STATUS_ABORTED:
-                return 2;
+                {
+                    ret = 2;
+                } break;
             case GGML_STATUS_ALLOC_FAILED:
-                return -2;
+                {
+                    ret = -2;
+                } break;
             case GGML_STATUS_FAILED:
             default:
-                return -3;
+                {
+                    ret = -3;
+                }
         }
     }
 
+    if (ret != 0) {
+        // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+        llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            pos_min[s] = std::numeric_limits<llama_pos>::max(); // max() marks a sequence as untouched
+        }
+
+        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+            const auto & seq_id = ubatch.seq_id[i][0];
+
+            pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+        }
+
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                continue;
+            }
+
+            llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+        }
+
+        return ret;
+    }
+
     // plot the computation graph in dot format (for debugging purposes)
     // if (n_past%100 == 0) {
     //     ggml_graph_dump_dot(gf, NULL, "llama.dot");
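Net effect of the two hunks: a failed or aborted decode now removes every position of the failed ubatch from the KV cache before returning, so the [pos_min, pos_max] guarantee added in llama.h keeps holding. A rough sketch of what this enables on the caller side (the retry wrapper `decode_with_retry` is illustrative and not part of this change; `llama_decode` and its return codes are the existing public API):

```cpp
#include "llama.h"

// illustrative wrapper: after this change, a failed llama_decode() leaves the
// KV cache as if the failed ubatch had never been submitted, so the same
// batch can simply be resubmitted without manual cleanup
static bool decode_with_retry(llama_context * ctx, llama_batch batch, int max_tries) {
    for (int t = 0; t < max_tries; ++t) {
        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return true;
        }
        // ret == 2 : aborted via the abort callback
        // ret <  0 : compute or allocation failure
        // for these codes the failed positions were already removed above
    }
    return false;
}
```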