
Commit 3249aab

add more tests
1 parent d7de413 commit 3249aab

File tree

7 files changed: +196 -18 lines changed

.github/workflows/server.yml

Lines changed: 2 additions & 2 deletions
@@ -180,12 +180,12 @@ jobs:
         run: |
           cd examples/server/tests
           $env:PYTHONIOENCODING = ":replace"
-          pytest -v -s
+          pytest -v -s -x
 
       - name: Slow tests
         id: server_integration_tests_slow
         if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
         run: |
           cd examples/server/tests
           $env:SLOW_TESTS = "1"
-          pytest -v -s
+          pytest -v -s -x
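
Note: the added -x flag makes pytest stop at the first failing test, so a broken run fails fast in CI instead of churning through the rest of the suite.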

examples/server/tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 
 
 # ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(autouse=True)
 def stop_server_after_each_test():
     # do nothing before each test
     yield
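
Dropping scope="session" gives the fixture pytest's default function scope, so the code after the yield now runs once per test instead of once per session. A minimal sketch of the resulting behavior, assuming ServerProcess exposes a stop() method and reusing the server_instances set from utils.py:

import pytest
from utils import server_instances

@pytest.fixture(autouse=True)  # default scope is "function": wraps every single test
def stop_server_after_each_test():
    yield  # the test body runs here
    # teardown now happens after each test: stop any servers still running
    for server in list(server_instances):
        server.stop()  # assumed API; stops the llama-server subprocess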

examples/server/tests/tests.sh

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ set -eu
 
 if [ $# -lt 1 ]
 then
-  pytest -v -s
+  pytest -v -s -x
 else
   pytest "$@"
 fi

examples/server/tests/unit/test_basic.py

Lines changed: 2 additions & 8 deletions
@@ -1,19 +1,13 @@
 import pytest
 from utils import *
 
-server = ServerProcess()
+server = ServerPreset.tinyllamas()
 
 
 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerProcess()
-    server.model_hf_repo = "ggml-org/models"
-    server.model_hf_file = "tinyllamas/stories260K.gguf"
-    server.n_ctx = 256
-    server.n_batch = 32
-    server.n_slots = 2
-    server.n_predict = 64
+    server = ServerPreset.tinyllamas()
 
 
 def test_server_start_simple():

examples/server/tests/unit/test_completion.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllamas()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllamas()
+
+
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+])
+def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": n_predict,
+        "prompt": prompt,
+    })
+    assert res.status_code == 200
+    assert res.body["timings"]["prompt_n"] == n_prompt
+    assert res.body["timings"]["predicted_n"] == n_predicted
+    assert res.body["truncated"] == truncated
+    assert match_regex(re_content, res.body["content"])
+
+
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+])
+def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/completion", data={
+        "n_predict": n_predict,
+        "prompt": prompt,
+        "stream": True,
+    })
+    content = ""
+    for data in res:
+        if data["stop"]:
+            assert data["timings"]["prompt_n"] == n_prompt
+            assert data["timings"]["predicted_n"] == n_predicted
+            assert data["truncated"] == truncated
+            assert match_regex(re_content, content)
+        else:
+            content += data["content"]
+
+
+# FIXME: This test is not working because /completions endpoint is not OAI-compatible
+@pytest.mark.skip(reason="Only /chat/completions is OAI-compatible for now")
+def test_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        prompt="I believe the meaning of life is",
+        n=8,
+        seed=42,
+        temperature=0.8,
+    )
+    print(res)
+    assert res.choices[0].finish_reason == "length"
+    assert match_regex("(going|bed)+", res.choices[0].text)

examples/server/tests/unit/test_tokenize.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import pytest
+from utils import *
+
+server = ServerPreset.tinyllamas()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllamas()
+
+
+def test_tokenize_detokenize():
+    global server
+    server.start()
+    # tokenize
+    content = "What is the capital of France ?"
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content
+    })
+    assert resTok.status_code == 200
+    assert len(resTok.body["tokens"]) > 5
+    # detokenize
+    resDetok = server.make_request("POST", "/detokenize", data={
+        "tokens": resTok.body["tokens"],
+    })
+    assert resDetok.status_code == 200
+    assert resDetok.body["content"].strip() == content
+
+
+def test_tokenize_with_bos():
+    global server
+    server.start()
+    # tokenize
+    content = "What is the capital of France ?"
+    bosId = 1
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content,
+        "add_special": True,
+    })
+    assert resTok.status_code == 200
+    assert resTok.body["tokens"][0] == bosId
+
+
+def test_tokenize_with_pieces():
+    global server
+    server.start()
+    # tokenize
+    content = "This is a test string with unicode 媽 and emoji 🤗"
+    resTok = server.make_request("POST", "/tokenize", data={
+        "content": content,
+        "with_pieces": True,
+    })
+    assert resTok.status_code == 200
+    for token in resTok.body["tokens"]:
+        assert "id" in token
+        assert token["id"] > 0
+        assert "piece" in token
+        assert len(token["piece"]) > 0
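
For reference, the with_pieces response asserted above has roughly the following shape; the token values here are invented for illustration, only the id/piece keys are guaranteed by the assertions:

# hypothetical /tokenize response body when "with_pieces": True is set
# (values made up; the test only checks the keys and basic invariants)
example_body = {
    "tokens": [
        {"id": 910, "piece": "This"},
        {"id": 338, "piece": " is"},
    ]
}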

examples/server/tests/utils.py

Lines changed: 61 additions & 6 deletions
@@ -5,6 +5,8 @@
 
 import subprocess
 import os
+import re
+import json
 import sys
 import threading
 import requests
@@ -19,6 +21,7 @@
     Sequence,
     Set,
 )
+from re import RegexFlag
 
 
 class ServerResponse:
@@ -34,6 +37,9 @@ class ServerProcess:
     server_host: str = "127.0.0.1"
     model_hf_repo: str = "ggml-org/models"
     model_hf_file: str = "tinyllamas/stories260K.gguf"
+    model_alias: str = "tinyllama-2"
+    temperature: float = 0.8
+    seed: int = 42
 
     # custom options
     model_alias: str | None = None
@@ -48,7 +54,6 @@ class ServerProcess:
     n_ga_w: int | None = None
     n_predict: int | None = None
     n_prompts: int | None = 0
-    n_server_predict: int | None = None
     slot_save_path: str | None = None
     id_slot: int | None = None
     cache_prompt: bool | None = None
@@ -58,12 +63,9 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
-    seed: int | None = None
     draft: int | None = None
-    server_seed: int | None = None
     user_api_key: str | None = None
     response_format: str | None = None
-    temperature: float | None = None
     lora_file: str | None = None
     disable_ctx_shift: int | None = False
 
@@ -86,6 +88,10 @@ def start(self, timeout_seconds: int = 10) -> None:
             self.server_host,
             "--port",
             self.server_port,
+            "--temp",
+            self.temperature,
+            "--seed",
+            self.seed,
         ]
         if self.model_file:
             server_args.extend(["--model", self.model_file])
@@ -119,8 +125,8 @@ def start(self, timeout_seconds: int = 10) -> None:
             server_args.extend(["--ctx-size", self.n_ctx])
         if self.n_slots:
             server_args.extend(["--parallel", self.n_slots])
-        if self.n_server_predict:
-            server_args.extend(["--n-predict", self.n_server_predict])
+        if self.n_predict:
+            server_args.extend(["--n-predict", self.n_predict])
         if self.slot_save_path:
             server_args.extend(["--slot-save-path", self.slot_save_path])
         if self.server_api_key:
@@ -216,12 +222,52 @@ def make_request(
         result.headers = dict(response.headers)
         result.status_code = response.status_code
         result.body = response.json()
+        print("Response from server", result.body)
         return result
+
+    def make_stream_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+    ) -> Iterator[dict]:
+        url = f"http://{self.server_host}:{self.server_port}{path}"
+        headers = {}
+        if self.user_api_key:
+            headers["Authorization"] = f"Bearer {self.user_api_key}"
+        if method == "POST":
+            response = requests.post(url, headers=headers, json=data, stream=True)
+        else:
+            raise ValueError(f"Unimplemented method: {method}")
+        for line_bytes in response.iter_lines():
+            line = line_bytes.decode("utf-8")
+            if '[DONE]' in line:
+                break
+            elif line.startswith('data: '):
+                data = json.loads(line[6:])
+                print("Partial response from server", data)
+                yield data
 
 
 server_instances: Set[ServerProcess] = set()
 
 
+class ServerPreset:
+    @staticmethod
+    def tinyllamas() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "tinyllamas/stories260K.gguf"
+        server.model_alias = "tinyllama-2"
+        server.n_ctx = 256
+        server.n_batch = 32
+        server.n_slots = 2
+        server.n_predict = 64
+        server.seed = 42
+        return server
+
+
 def multiple_post_requests(
     server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None
 ) -> Sequence[ServerResponse]:
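
The new ServerPreset class centralizes the tinyllama configuration that test modules previously duplicated inline, and make_stream_request exposes the server's SSE stream as an iterator of parsed JSON chunks. A minimal sketch of how a test combines the two, mirroring the new unit tests above:

from utils import ServerPreset

# reuse the shared preset instead of configuring ServerProcess field by field
server = ServerPreset.tinyllamas()
server.start()

# non-streaming request: one JSON body back
res = server.make_request("POST", "/completion", data={"prompt": "Hello", "n_predict": 8})
print(res.body["content"])

# streaming request: iterate over the parsed "data: {...}" chunks
chunks = server.make_stream_request("POST", "/completion", data={
    "prompt": "Hello",
    "n_predict": 8,
    "stream": True,
})
text = "".join(chunk["content"] for chunk in chunks if not chunk["stop"])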
@@ -248,3 +294,12 @@ def thread_target(data_chunk):
         thread.join()
 
     return results
+
+
+def match_regex(regex: str, text: str) -> bool:
+    return (
+        re.compile(
+            regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
+        ).search(text)
+        is not None
+    )
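
match_regex does a case-insensitive, multi-line, dot-matches-newline search rather than a full match, so a pattern only needs to occur somewhere in the generated text. For example:

from utils import match_regex

assert match_regex("(going|bed)+", "One day she was GOING to the park.")  # substring hit, any case
assert not match_regex("(princesses)+", "nothing relevant here")          # no occurrence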
