Commit ab2e275

llama : handle aborts and compute errors

ggml-ci

1 parent 2252eef

2 files changed: +34 −3

include/llama.h

Lines changed: 2 additions & 0 deletions

```diff
@@ -677,12 +677,14 @@ extern "C" {
 
     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                    llama_seq_id   seq_id);
 
     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
```
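
The new comments make a contiguity guarantee explicit: for a non-empty sequence, every position between `llama_kv_self_seq_pos_min` and `llama_kv_self_seq_pos_max` is present in the KV cache. Below is a minimal sketch of how a caller might rely on this, assuming an existing `llama_context * ctx`; the helper name `seq_next_pos` is hypothetical and not part of llama.h:

```cpp
#include "llama.h"

// Returns the next position that can be decoded for `seq_id`, or 0 if the
// sequence has no cached tokens. Hypothetical helper, not part of llama.h.
static llama_pos seq_next_pos(llama_context * ctx, llama_seq_id seq_id) {
    const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
    const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

    if (p_min == -1 || p_max == -1) {
        return 0; // empty sequence: start decoding from position 0
    }

    // all positions in [p_min, p_max] are guaranteed to be present,
    // so decoding can safely continue at p_max + 1 with no gaps
    return p_max + 1;
}
```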

src/llama-context.cpp

Lines changed: 32 additions & 3 deletions

```diff
@@ -951,19 +951,48 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         res->set_inputs(&ubatch);
 
+        int ret = 0;
+
         const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
-                    return 2;
+                    {
+                        ret = 2;
+                    } break;
                 case GGML_STATUS_ALLOC_FAILED:
-                    return -2;
+                    {
+                        ret = -2;
+                    } break;
                 case GGML_STATUS_FAILED:
                 default:
-                    return -3;
+                    {
+                        ret = -3;
+                    }
             }
         }
 
+        if (ret != 0) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
+
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
+
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
+
+                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+            }
+
+            return ret;
+        }
+
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
```
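
With this change, a failed or aborted decode no longer leaves a partially processed ubatch in the KV cache: the positions of the failing ubatch are removed per sequence before the error is returned. Below is a sketch of caller-side handling under that assumption; the return-code mapping (2 = aborted, -2 = allocation failed, -3 = compute failed) follows the switch above, and the wrapper name is illustrative, not part of the library:

```cpp
#include <cstdio>
#include "llama.h"

// Illustrative wrapper: decode a batch and report the error class.
static int decode_checked(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);

    switch (ret) {
        case 0:
            break; // success
        case 2:
            // aborted via the abort callback; the failing ubatch was removed
            // from the KV cache, so the same batch can be resubmitted later
            fprintf(stderr, "decode aborted\n");
            break;
        case -2:
            fprintf(stderr, "decode failed: allocation error\n");
            break;
        case -3:
            fprintf(stderr, "decode failed: compute error\n");
            break;
        default:
            fprintf(stderr, "decode returned %d\n", ret);
            break;
    }

    return ret;
}
```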
