Skip to content

Commit e721f4c

Browse files
committed
add tests for /props and /slots
1 parent 6bf6e30 commit e721f4c

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

examples/server/tests/unit/test_basic.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def test_server_props():
2323
res = server.make_request("GET", "/props")
2424
assert res.status_code == 200
2525
assert res.body["total_slots"] == server.n_slots
26+
default_val = res.body["default_generation_settings"]
27+
assert server.n_ctx is not None and server.n_slots is not None
28+
assert default_val["n_ctx"] == server.n_ctx / server.n_slots
29+
assert default_val["params"]["seed"] == server.seed
2630

2731

2832
def test_server_models():
@@ -36,12 +40,26 @@ def test_server_models():
3640

3741
def test_server_slots():
    """Check that GET /slots honours the server's `server_slots` setting.

    The test runs the shared `server` twice:

    1. With `server_slots = False`, the request must come back with
       status 501 and a body containing an "error" key.
    2. With `server_slots = True` and two slots, the request must come
       back with status 200 and one entry per slot; the first entry must
       report the expected per-slot context size and the configured seed.

    The server started in the second phase is left running when the
    function returns.
    """
    global server

    # without slots endpoint enabled, this should return error
    server.server_slots = False
    server.start()
    res = server.make_request("GET", "/slots")
    assert res.status_code == 501  # ERROR_TYPE_NOT_SUPPORTED
    assert "error" in res.body
    server.stop()

    # with slots endpoint enabled, this should return slots info
    server.server_slots = True
    server.n_slots = 2
    server.start()
    res = server.make_request("GET", "/slots")
    assert res.status_code == 200
    assert len(res.body) == server.n_slots

    # Both values must be set before they can be used in arithmetic below.
    assert server.n_ctx is not None and server.n_slots is not None
    # The per-slot context size is an integer, so compare against floor
    # division rather than true division (which yields a float and fails
    # spuriously when n_ctx is not an exact multiple of n_slots).
    # NOTE(review): presumably the server splits the context across slots
    # with integer division -- confirm against the server implementation.
    assert res.body[0]["n_ctx"] == server.n_ctx // server.n_slots
    assert "params" in res.body[0]
    assert res.body[0]["params"]["seed"] == server.seed
4563

4664

4765
def test_load_split_model():

examples/server/tests/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def start(self, timeout_seconds: int = 10) -> None:
9292
else:
9393
server_path = "../../../build/bin/llama-server"
9494
server_args = [
95-
"--slots", # requires to get slot status via /slots endpoint
9695
"--host",
9796
self.server_host,
9897
"--port",
@@ -184,7 +183,7 @@ def start(self, timeout_seconds: int = 10) -> None:
184183
start_time = time.time()
185184
while time.time() - start_time < timeout_seconds:
186185
try:
187-
response = self.make_request("GET", "/slots", headers={
186+
response = self.make_request("GET", "/health", headers={
188187
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
189188
})
190189
if response.status_code == 200:
@@ -227,7 +226,7 @@ def make_request(
227226
result.headers = dict(response.headers)
228227
result.status_code = response.status_code
229228
result.body = response.json() if parse_body else None
230-
print("Response from server", result.body)
229+
print("Response from server", json.dumps(result.body, indent=2))
231230
return result
232231

233232
def make_stream_request(
@@ -248,7 +247,7 @@ def make_stream_request(
248247
break
249248
elif line.startswith('data: '):
250249
data = json.loads(line[6:])
251-
print("Partial response from server", data)
250+
print("Partial response from server", json.dumps(data, indent=2))
252251
yield data
253252

254253

0 commit comments

Comments
 (0)