
Commit 3acaf58

server : replace behave with pytest
1 parent 42ae10b commit 3acaf58

6 files changed: +298 -4 lines changed


examples/server/tests/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 .venv
+tmp

examples/server/tests/conftest.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+import pytest
+from utils import *
+
+
+# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
+@pytest.fixture(autouse=True)
+def stop_server_after_each_test():
+    # do nothing before each test
+    yield
+    # stop all servers after each test
+    instances = set(
+        server_instances
+    )  # copy the set to prevent 'Set changed size during iteration'
+    for server in instances:
+        server.stop()
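
For orientation, a minimal sketch (not part of this commit) of how a test interacts with this fixture: because it is autouse, tests never reference it directly; any ServerProcess that registered itself in server_instances is stopped during teardown. The test name and endpoint below are illustrative assumptions.

# hypothetical standalone test module, assuming utils.ServerProcess as added below
from utils import ServerProcess

def test_health():
    server = ServerProcess()  # registers itself in server_instances on start()
    server.start()
    res = server.make_request("GET", "/health")
    assert res.status_code == 200
    # no explicit server.stop(): the autouse fixture in conftest.py cleans up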

examples/server/tests/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 aiohttp~=3.9.3
-behave~=1.2.6
+pytest~=8.3.3
 huggingface_hub~=0.23.2
 numpy~=1.26.4
 openai~=1.30.3

examples/server/tests/tests.sh

Lines changed: 2 additions & 3 deletions
@@ -4,8 +4,7 @@ set -eu

 if [ $# -lt 1 ]
 then
-  # Start @llama.cpp scenario
-  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
+  pytest -v -s
 else
-  behave "$@"
+  pytest "$@"
 fi
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import pytest
+from utils import *
+
+server = ServerProcess()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerProcess()
+    server.model_hf_repo = "ggml-org/models"
+    server.model_hf_file = "tinyllamas/stories260K.gguf"
+    server.n_ctx = 256
+    server.n_batch = 32
+    server.n_slots = 2
+    server.n_predict = 64
+
+
+def test_server_start_simple():
+    global server
+    server.start()
+    res = server.make_request("GET", "/health")
+    assert res.status_code == 200
+
+
+def test_server_props():
+    global server
+    server.start()
+    res = server.make_request("GET", "/props")
+    assert res.status_code == 200
+    assert res.body["total_slots"] == server.n_slots
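
A possible follow-up test in the same style (not part of this commit): the /completion request body below assumes the server's usual JSON fields ("prompt", "n_predict", "content") and is illustrative only.

def test_completion_simple():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "Once upon a time",  # hypothetical prompt
        "n_predict": 16,
    })
    assert res.status_code == 200
    assert "content" in res.body  # assumed response field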

examples/server/tests/utils.py

Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import subprocess
+import os
+import sys
+import threading
+import requests
+import time
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Iterable,
+    Iterator,
+    Literal,
+    Sequence,
+    Set,
+)
+
+
+class ServerResponse:
+    headers: dict
+    status_code: int
+    body: dict
+
+
+class ServerProcess:
+    # default options
+    debug: bool = False
+    server_port: int = 8080
+    server_host: str = "127.0.0.1"
+    model_hf_repo: str = "ggml-org/models"
+    model_hf_file: str = "tinyllamas/stories260K.gguf"
+
+    # custom options
+    model_alias: str | None = None
+    model_url: str | None = None
+    model_file: str | None = None
+    n_threads: int | None = None
+    n_gpu_layer: str | None = None
+    n_batch: int | None = None
+    n_ubatch: int | None = None
+    n_ctx: int | None = None
+    n_ga: int | None = None
+    n_ga_w: int | None = None
+    n_predict: int | None = None
+    n_prompts: int | None = 0
+    n_server_predict: int | None = None
+    slot_save_path: str | None = None
+    id_slot: int | None = None
+    cache_prompt: bool | None = None
+    n_slots: int | None = None
+    server_api_key: str | None = None
+    server_continuous_batching: bool | None = False
+    server_embeddings: bool | None = False
+    server_reranking: bool | None = False
+    server_metrics: bool | None = False
+    seed: int | None = None
+    draft: int | None = None
+    server_seed: int | None = None
+    user_api_key: str | None = None
+    response_format: str | None = None
+    temperature: float | None = None
+    lora_file: str | None = None
+    disable_ctx_shift: bool | None = False
+
+    # session variables
+    process: subprocess.Popen | None = None
+
+    def __init__(self):
+        pass
+
+    def start(self, timeout_seconds: int = 10) -> None:
+        if "LLAMA_SERVER_BIN_PATH" in os.environ:
+            server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
+        elif os.name == "nt":
+            server_path = "../../../build/bin/Release/llama-server.exe"
+        else:
+            server_path = "../../../build/bin/llama-server"
+        server_args = [
+            "--slots",  # needed so that slot status can be queried via the /slots endpoint
+            "--host",
+            self.server_host,
+            "--port",
+            self.server_port,
+        ]
+        if self.model_file:
+            server_args.extend(["--model", self.model_file])
+        if self.model_url:
+            server_args.extend(["--model-url", self.model_url])
+        if self.model_hf_repo:
+            server_args.extend(["--hf-repo", self.model_hf_repo])
+        if self.model_hf_file:
+            server_args.extend(["--hf-file", self.model_hf_file])
+        if self.n_batch:
+            server_args.extend(["--batch-size", self.n_batch])
+        if self.n_ubatch:
+            server_args.extend(["--ubatch-size", self.n_ubatch])
+        if self.n_threads:
+            server_args.extend(["--threads", self.n_threads])
+        if self.n_gpu_layer:
+            server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
+        if self.draft is not None:
+            server_args.extend(["--draft", self.draft])
+        if self.server_continuous_batching:
+            server_args.append("--cont-batching")
+        if self.server_embeddings:
+            server_args.append("--embedding")
+        if self.server_reranking:
+            server_args.append("--reranking")
+        if self.server_metrics:
+            server_args.append("--metrics")
+        if self.model_alias:
+            server_args.extend(["--alias", self.model_alias])
+        if self.n_ctx:
+            server_args.extend(["--ctx-size", self.n_ctx])
+        if self.n_slots:
+            server_args.extend(["--parallel", self.n_slots])
+        if self.n_server_predict:
+            server_args.extend(["--n-predict", self.n_server_predict])
+        if self.slot_save_path:
+            server_args.extend(["--slot-save-path", self.slot_save_path])
+        if self.server_api_key:
+            server_args.extend(["--api-key", self.server_api_key])
+        if self.n_ga:
+            server_args.extend(["--grp-attn-n", self.n_ga])
+        if self.n_ga_w:
+            server_args.extend(["--grp-attn-w", self.n_ga_w])
+        if self.debug:
+            server_args.append("--verbose")
+        if self.lora_file:
+            server_args.extend(["--lora", self.lora_file])
+        if self.disable_ctx_shift:
+            server_args.extend(["--no-context-shift"])
+
+        args = [str(arg) for arg in [server_path, *server_args]]
+        print(f"bench: starting server with: {' '.join(args)}")
+
+        flags = 0
+        if "nt" == os.name:
+            flags |= subprocess.DETACHED_PROCESS
+            flags |= subprocess.CREATE_NEW_PROCESS_GROUP
+            flags |= subprocess.CREATE_NO_WINDOW
+
+        self.process = subprocess.Popen(
+            args,
+            creationflags=flags,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env={**os.environ, "LLAMA_CACHE": "tmp"},
+        )
+        server_instances.add(self)
+
+        def server_log(in_stream, out_stream):
+            for line in iter(in_stream.readline, b""):
+                print(line.decode("utf-8"), end="", file=out_stream)
+
+        thread_stdout = threading.Thread(
+            target=server_log, args=(self.process.stdout, sys.stdout), daemon=True
+        )
+        thread_stdout.start()
+
+        thread_stderr = threading.Thread(
+            target=server_log, args=(self.process.stderr, sys.stderr), daemon=True
+        )
+        thread_stderr.start()
+
+        print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
+
+        # wait for server to start
+        start_time = time.time()
+        while time.time() - start_time < timeout_seconds:
+            try:
+                response = self.make_request("GET", "/slots")
+                if response.status_code == 200:
+                    self.ready = True
+                    return  # server is ready
+            except Exception:
+                pass
+            print("Waiting for server to start...")
+            time.sleep(0.5)
+        raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
+
+    def stop(self) -> None:
+        server_instances.remove(self)
+        if self.process:
+            print(f"Stopping server with pid={self.process.pid}")
+            self.process.kill()
+            self.process = None
+
+    def make_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+    ) -> ServerResponse:
+        url = f"http://{self.server_host}:{self.server_port}{path}"
+        headers = {} if headers is None else headers  # keep caller-supplied headers
+        if self.user_api_key:
+            headers["Authorization"] = f"Bearer {self.user_api_key}"
+        if self.response_format:
+            headers["Accept"] = self.response_format
+        if method == "GET":
+            response = requests.get(url, headers=headers)
+        elif method == "POST":
+            response = requests.post(url, headers=headers, json=data)
+        elif method == "OPTIONS":
+            response = requests.options(url, headers=headers)
+        else:
+            raise ValueError(f"Unimplemented method: {method}")
+        result = ServerResponse()
+        result.headers = dict(response.headers)
+        result.status_code = response.status_code
+        result.body = response.json()
+        return result
+
+
+server_instances: Set[ServerProcess] = set()
+
+
+def multiple_post_requests(
+    server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
+) -> Sequence[ServerResponse]:
+    def worker(data_chunk):
+        try:
+            return server.make_request("POST", path, data=data_chunk, headers=headers)
+        except Exception as e:
+            print(f"Error occurred: {e}", file=sys.stderr)
+            os._exit(1)  # terminate main thread
+
+    threads = []
+    results = []
+
+    def thread_target(data_chunk):
+        result = worker(data_chunk)
+        results.append(result)
+
+    for chunk in data:
+        thread = threading.Thread(target=thread_target, args=(chunk,))
+        threads.append(thread)
+        thread.start()
+
+    for thread in threads:
+        thread.join()
+
+    return results
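
As a usage sketch (not part of this commit), multiple_post_requests can exercise several slots at once; the /completion path, JSON fields, and test name below are assumptions for illustration only.

# hypothetical concurrency check built on the helpers above
from utils import ServerProcess, multiple_post_requests

def test_parallel_completions():
    server = ServerProcess()
    server.n_slots = 2
    server.server_continuous_batching = True
    server.start()
    prompts = [{"prompt": f"Story {i}", "n_predict": 8} for i in range(4)]  # assumed fields
    results = multiple_post_requests(server, "/completion", prompts)
    assert all(r.status_code == 200 for r in results)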
