2 files changed: +34 −3 lines

@@ -677,12 +677,14 @@ extern "C" {
 
     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                    llama_seq_id   seq_id);
 
     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
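For reference, a minimal sketch of how a caller might lean on the newly documented guarantee. This is not part of the change; `ctx` is assumed to be an already-initialized `llama_context *` from the usual llama.cpp setup, and sequence 0 is used purely for illustration.

```cpp
// sketch: inspect the contiguous span of sequence 0 that is currently in the KV cache
const llama_seq_id seq_id = 0;

const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

if (p_min == -1) {
    // the sequence has no tokens in the KV cache
} else {
    // every position in [p_min, p_max] is present, so the cached length is simply:
    const int64_t n_cached = (int64_t) p_max - p_min + 1;

    // for SWA caches p_min may be greater than 0, i.e. older positions were evicted
    (void) n_cached;
}
```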
@@ -951,19 +951,48 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
+    int ret = 0;
+
     const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
     if (compute_status != GGML_STATUS_SUCCESS) {
         switch (compute_status) {
             case GGML_STATUS_ABORTED:
-                return 2;
+                {
+                    ret = 2;
+                } break;
             case GGML_STATUS_ALLOC_FAILED:
-                return -2;
+                {
+                    ret = -2;
+                } break;
             case GGML_STATUS_FAILED:
             default:
-                return -3;
+                {
+                    ret = -3;
+                }
         }
     }
 
+    if (ret != 0) {
+        // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+        llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+
+        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+            const auto & seq_id = ubatch.seq_id[i][0];
+
+            pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+        }
+
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                continue;
+            }
+
+            llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+        }
+
+        return ret;
+    }
+
     // plot the computation graph in dot format (for debugging purposes)
     // if (n_past%100 == 0) {
     //     ggml_graph_dump_dot(gf, NULL, "llama.dot");
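One subtlety worth flagging in the added block: a braced initializer with a single value only sets the first array element and value-initializes the rest to 0, so per-sequence sentinels are safer to set explicitly. The standalone sketch below restates the rollback idea under that assumption; the `mini_ubatch` type, `MAX_SEQ`, and `collect_rollback_points` are hypothetical stand-ins, not llama.cpp internals.

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <limits>

// hypothetical stand-ins for the llama.cpp types used in the diff above
using llama_pos    = int32_t;
using llama_seq_id = int32_t;

constexpr int MAX_SEQ = 64; // stands in for LLAMA_MAX_PARALLEL_SEQUENCES

struct mini_ubatch {
    uint32_t             n_tokens;
    const llama_pos    * pos;     // pos[i]     - position of token i
    const llama_seq_id * seq_id0; // seq_id0[i] - primary sequence of token i
};

// record, per sequence, the smallest position touched by the failed ubatch;
// the caller would then erase [pos_min[s], end) for each affected sequence s
void collect_rollback_points(const mini_ubatch & ubatch, llama_pos (&pos_min)[MAX_SEQ]) {
    // start every slot at the sentinel, not just the first one
    std::fill(std::begin(pos_min), std::end(pos_min), std::numeric_limits<llama_pos>::max());

    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
        const llama_seq_id s = ubatch.seq_id0[i];

        pos_min[s] = std::min(pos_min[s], ubatch.pos[i]);
    }

    // slots still equal to the sentinel belong to sequences the ubatch never touched
    // and must be skipped when erasing from the KV cache
}
```

For callers, the practical upshot is that a non-zero return from decode should no longer leave a partially processed ubatch behind in the KV cache, so a failed batch can typically be split and retried without manual cleanup via llama_kv_self_seq_rm().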