Commit 5b73d10

Author: Anivar A Aravind
Fix critical KV cache crash bug causing std::length_error
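Fixes an unsigned integer underflow in the cache_tokens resize. When n_discard is larger than cache_tokens.size(), the expression cache_tokens.size() - n_discard wraps around to a huge value, and resize() then throws std::length_error. This commonly occurs during KV cache context shifting, particularly with Chinese text translation workloads. The fix checks that n_discard is in range before resizing the cache_tokens vector and clears the vector otherwise. Fixes #771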
1 parent 24f18f7 commit 5b73d10

1 file changed: +6 -1 lines changed

llama.cpp/server/server.cpp

Lines changed: 6 additions & 1 deletion
@@ -1711,7 +1711,12 @@ struct llama_server_context
     slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
 }
 
-slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+// Prevent integer underflow that causes std::length_error
+if (n_discard >= 0 && (size_t)n_discard < slot.cache_tokens.size()) {
+    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+} else {
+    slot.cache_tokens.clear();
+}
 
 slot.n_past -= n_discard;
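
The guard in this hunk can be exercised in isolation. The following is a minimal sketch, assuming a plain std::vector<int> stands in for slot.cache_tokens. Both function names are hypothetical and do not exist in server.cpp. The first function mirrors the patched branch. The second computes the size the unpatched line would have passed to resize(), to show why it fails.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of server.cpp): applies the same bounds check
// as the patch to a vector standing in for slot.cache_tokens.
static void shrink_after_discard(std::vector<int> & cache_tokens, int n_discard) {
    if (n_discard >= 0 && static_cast<std::size_t>(n_discard) < cache_tokens.size()) {
        // In range: drop the last n_discard entries.
        cache_tokens.resize(cache_tokens.size() - static_cast<std::size_t>(n_discard));
    } else {
        // Negative, equal to, or larger than the current size: empty the vector
        // instead of letting the subtraction wrap around.
        cache_tokens.clear();
    }
}

// Hypothetical helper: the new size the unguarded code would have requested.
// n_discard is converted to an unsigned type, so the subtraction wraps
// whenever n_discard is larger than the vector's size.
static std::size_t unguarded_new_size(const std::vector<int> & cache_tokens, int n_discard) {
    return cache_tokens.size() - static_cast<std::size_t>(n_discard);
}
```

With three cached tokens and n_discard = 5, unguarded_new_size returns a value close to the maximum of std::size_t. That is above max_size() for the vector, which is the condition under which resize() throws std::length_error. When n_discard equals the size, clear() gives the same result as resize(0), so taking the else branch in that case does not change behavior.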