From dc290c8d9f9b6aae0954ef79686f7c5eb3f852e6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 7 Nov 2025 15:56:36 +0200 Subject: [PATCH 1/2] server : handle failures to restore host cache --- tools/server/server.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 164e8cf4e7084..dc3889448fbfa 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1690,6 +1690,9 @@ struct server_slot { bool res = prompt_cache.load(prompt, tokens, ctx, id); if (!res) { SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); + + llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1); + prompt.tokens.clear(); } } From 0d38370473335617258fef60d6593c7f64980c45 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 8 Nov 2025 10:42:51 +0200 Subject: [PATCH 2/2] server : add tests for the prompt cache --- tools/server/tests/unit/test_completion.py | 42 ++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index 3c0ce98973f4b..ef1757db21f7f 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -1,6 +1,8 @@ import pytest import requests import time +import random + from openai import OpenAI from utils import * @@ -564,3 +566,43 @@ def test_cancel_request(): time.sleep(1) # wait for HTTP_POLLING_SECONDS res = server.make_request("GET", "/slots") assert res.body[0]["is_processing"] == False + + +# this test exercises the host-memory prompt cache +# ref: https://github.com/ggml-org/llama.cpp/pull/16391 +# ref: https://github.com/ggml-org/llama.cpp/pull/17078 +def test_completion_prompt_cache(): + global server + server.n_slots = 2 + server.kv_unified = True + server.start() + + for _ in range(16): + # generate alternating random prompts with variable lengths in order to get them in and out of the cache + r = random.randint(0, 4) + prompt = (" Hello " + str(r)) * (40 + r) + n_prompt = (40 + r)*5 + 2 + n_predict = random.randint(1, 8) + + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": prompt, + "n_predict": n_predict, + }, + ) + + assert res.status_code == 200 + assert "content" in res.body + content = res.body["content"] + assert isinstance(content, str) + assert len(content) > 0 + + assert type(res.body["has_new_line"]) == bool + assert "timings" in res.body + timings = res.body["timings"] + + assert "prompt_n" in timings and timings["prompt_n"] + timings["cache_n"] == n_prompt + assert "predicted_n" in timings and timings["predicted_n"] == n_predict + assert "tokens" in res.body and isinstance(res.body["tokens"], list)