Skip to content

Commit 49f7a3c

Browse files
committed
tests : add unified cache server tests
1 parent 7c7f3bf commit 49f7a3c

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

tools/server/tests/unit/test_completion.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,37 @@ def check_slots_status():
368368
# assert match_regex(re_content, res.body["content"])
369369

370370

371+
@pytest.mark.parametrize(
    "n_ctx,n_slots,n_predict_vals,expected_success",
    [
        (256, 4, [80, 40, 80, 80], [True, True, True, True]),
        (256, 4, [70, 70, 70, 70], [False, False, False, False]),
        (256, 4, [90, 90, 40, 90], [False, False, True, False]),
        (256, 4, [90, 90, 40, 80], [True, True, True, True]),
    ],
)
def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
    """Check parallel /completion requests against a server started with kv_unified set.

    One request is issued per entry of ``n_predict_vals`` (all with the same
    one-character prompt), and the requests are dispatched together through
    ``parallel_function_calls``. ``expected_success`` is matched to the
    responses positionally: a True entry requires HTTP 200 with a "content"
    field, a False entry requires HTTP 500 without one.
    """
    global server
    server.n_slots = n_slots
    server.kv_unified = True
    server.n_ctx = n_ctx
    server.start()

    # One (callable, args) pair per requested generation length.
    requests = [
        (server.make_request, ("POST", "/completion", {"prompt": "A", "n_predict": budget}))
        for budget in n_predict_vals
    ]
    responses = parallel_function_calls(requests)

    for response, budget, should_succeed in zip(responses, n_predict_vals, expected_success):
        if not should_succeed:
            assert response.status_code == 500
            assert "content" not in response.body
            continue

        assert response.status_code == 200
        assert "content" in response.body
        # Timings are only verified when the response actually carries them.
        if "timings" in response.body:
            assert response.body["timings"]["predicted_n"] == budget
400+
401+
371402
@pytest.mark.parametrize(
372403
"prompt,n_predict,response_fields",
373404
[

tools/server/tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class ServerProcess:
7878
server_embeddings: bool | None = False
7979
server_reranking: bool | None = False
8080
server_metrics: bool | None = False
81+
kv_unified: bool | None = False
8182
server_slots: bool | None = False
8283
pooling: str | None = None
8384
draft: int | None = None
@@ -159,6 +160,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
159160
server_args.append("--reranking")
160161
if self.server_metrics:
161162
server_args.append("--metrics")
163+
if self.kv_unified:
164+
server_args.append("--kv-unified")
162165
if self.server_slots:
163166
server_args.append("--slots")
164167
else:

0 commit comments

Comments
 (0)