Commit f09a9b6: "more tests" (1 parent: 3249aab)
File tree: 7 files changed, +417 -36 lines
examples/server/tests/unit/test_basic.py

Lines changed: 11 additions & 2 deletions

@@ -1,13 +1,13 @@
 import pytest
 from utils import *
 
-server = ServerPreset.tinyllamas()
+server = ServerPreset.tinyllama2()
 
 
 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerPreset.tinyllamas()
+    server = ServerPreset.tinyllama2()
 
 
 def test_server_start_simple():
@@ -23,3 +23,12 @@ def test_server_props():
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
     assert res.body["total_slots"] == server.n_slots
+
+
+def test_server_models():
+    global server
+    server.start()
+    res = server.make_request("GET", "/models")
+    assert res.status_code == 200
+    assert len(res.body["data"]) == 1
+    assert res.body["data"][0]["id"] == server.model_alias
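
Note: the tests in this commit talk to a running server through a small harness imported from utils (ServerPreset, server.make_request), which is not part of this diff. As a rough sketch only, with hypothetical names and the requests library standing in for whatever the real utils.py does, a minimal make_request equivalent could look like:

import requests

class ServerResponse:
    # Carries the two things the tests assert on: HTTP status and decoded JSON body.
    def __init__(self, status_code: int, body):
        self.status_code = status_code
        self.body = body

def make_request(method: str, path: str, host: str = "127.0.0.1",
                 port: int = 8080, data: dict | None = None) -> ServerResponse:
    # Issue the request against the locally started server and decode the
    # JSON body so tests can assert on plain dicts and lists.
    res = requests.request(method, f"http://{host}:{port}{path}", json=data)
    return ServerResponse(res.status_code, res.json())
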
examples/server/tests/unit/test_chat_completion.py

Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.tinyllama2()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+
+
+@pytest.mark.parametrize(
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    [
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+    ]
+)
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "model": model,
+        "max_tokens": max_tokens,
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+    })
+    assert res.status_code == 200
+    assert res.body["usage"]["prompt_tokens"] == n_prompt
+    assert res.body["usage"]["completion_tokens"] == n_predicted
+    choice = res.body["choices"][0]
+    assert "assistant" == choice["message"]["role"]
+    assert match_regex(re_content, choice["message"]["content"])
+    if truncated:
+        assert choice["finish_reason"] == "length"
+    else:
+        assert choice["finish_reason"] == "stop"
+
+
+@pytest.mark.parametrize(
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
+    [
+        ("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
+    ]
+)
+def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/chat/completions", data={
+        "model": model,
+        "max_tokens": max_tokens,
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+        "stream": True,
+    })
+    content = ""
+    for data in res:
+        choice = data["choices"][0]
+        if choice["finish_reason"] in ["stop", "length"]:
+            assert data["usage"]["prompt_tokens"] == n_prompt
+            assert data["usage"]["completion_tokens"] == n_predicted
+            assert "content" not in choice["delta"]
+            assert match_regex(re_content, content)
+            # FIXME: not sure why this is incorrect in stream mode
+            # if truncated:
+            #     assert choice["finish_reason"] == "length"
+            # else:
+            #     assert choice["finish_reason"] == "stop"
+        else:
+            assert choice["finish_reason"] is None
+            content += choice["delta"]["content"]
+
+
+def test_chat_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=8,
+        seed=42,
+        temperature=0.8,
+    )
+    print(res)
+    assert res.choices[0].finish_reason == "stop"
+    assert res.choices[0].message.content is not None
+    assert match_regex("(Suddenly)+", res.choices[0].message.content)
+
+
+@pytest.mark.parametrize("response_format,n_predicted,re_content", [
+    ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
+    ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
+    ({"type": "json_object"}, 10, "(\\{|John)+"),
+    ({"type": "sound"}, 0, None),
+    # invalid response format (expected to fail)
+    ({"type": "json_object", "schema": 123}, 0, None),
+    ({"type": "json_object", "schema": {"type": 123}}, 0, None),
+    ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
+])
+def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "Write an example"},
+        ],
+        "response_format": response_format,
+    })
+    if re_content is not None:
+        assert res.status_code == 200
+        choice = res.body["choices"][0]
+        assert match_regex(re_content, choice["message"]["content"])
+    else:
+        assert res.status_code != 200
+        assert "error" in res.body
+
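The chat tests above check generated text with match_regex, also imported from utils and not shown in this diff. A plausible minimal version, assuming it is just a permissive re.search wrapper (a sketch, not necessarily the repo's actual helper):

import re

def match_regex(pattern: str, text: str) -> bool:
    # True when the pattern occurs anywhere in the generated text; the
    # permissive flags keep assertions robust to casing and line breaks.
    return re.search(pattern, text,
                     flags=re.IGNORECASE | re.MULTILINE | re.DOTALL) is not None
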
examples/server/tests/unit/test_completion.py

Lines changed: 55 additions & 3 deletions

@@ -2,13 +2,13 @@
 from openai import OpenAI
 from utils import *
 
-server = ServerPreset.tinyllamas()
+server = ServerPreset.tinyllama2()
 
 
 @pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
-    server = ServerPreset.tinyllamas()
+    server = ServerPreset.tinyllama2()
 
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
@@ -61,10 +61,62 @@ def test_completion_with_openai_library():
     res = client.completions.create(
         model="gpt-3.5-turbo-instruct",
         prompt="I believe the meaning of life is",
-        n=8,
+        max_tokens=8,
         seed=42,
         temperature=0.8,
     )
     print(res)
     assert res.choices[0].finish_reason == "length"
     assert match_regex("(going|bed)+", res.choices[0].text)
+
+
+@pytest.mark.parametrize("n_slots", [1, 2])
+def test_consistent_result_same_seed(n_slots: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
+    last_res = None
+    for _ in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": 42,
+            "temperature": 1.0,
+        })
+        if last_res is not None:
+            assert res.body["content"] == last_res.body["content"]
+        last_res = res
+
+
+@pytest.mark.parametrize("n_slots", [1, 2])
+def test_different_result_different_seed(n_slots: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
+    last_res = None
+    for seed in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": seed,
+            "temperature": 1.0,
+        })
+        if last_res is not None:
+            assert res.body["content"] != last_res.body["content"]
+        last_res = res
+
+
+@pytest.mark.parametrize("n_batch", [16, 32])
+@pytest.mark.parametrize("temperature", [0.0, 1.0])
+def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
+    global server
+    server.n_batch = n_batch
+    server.start()
+    last_res = None
+    for _ in range(4):
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "seed": 42,
+            "temperature": temperature,
+        })
+        if last_res is not None:
+            assert res.body["content"] == last_res.body["content"]
+        last_res = res
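
The three determinism tests added above (same seed, different seed, different batch size) can be run on their own; a small illustrative runner using pytest's programmatic entry point (file path as introduced in this commit, keyword filter taken from the test names above):

import pytest

if __name__ == "__main__":
    # Select only the determinism tests from this file by keyword.
    raise SystemExit(pytest.main([
        "-v",
        "-k", "consistent_result or different_seed",
        "examples/server/tests/unit/test_completion.py",
    ]))
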
examples/server/tests/unit/test_embedding.py

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+import pytest
+from openai import OpenAI
+from utils import *
+
+server = ServerPreset.bert_bge_small()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.bert_bge_small()
+
+
+def test_embedding_single():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "I believe the meaning of life is",
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 1
+    assert 'embedding' in res.body['data'][0]
+    assert len(res.body['data'][0]['embedding']) > 1
+
+    # make sure embedding vector is normalized
+    assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < 1e-6
+
+
+def test_embedding_multiple():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "Write a joke about AI from a very long prompt which will not be truncated",
+            "This is a test",
+            "This is another test",
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
+def test_embedding_openai_library_single():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
+    assert len(res.data) == 1
+    assert len(res.data[0].embedding) > 1
+
+
+def test_embedding_openai_library_multiple():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.embeddings.create(model="text-embedding-3-small", input=[
+        "I believe the meaning of life is",
+        "Write a joke about AI from a very long prompt which will not be truncated",
+        "This is a test",
+        "This is another test",
+    ])
+    assert len(res.data) == 4
+    for d in res.data:
+        assert len(d.embedding) > 1
+
+
+def test_embedding_error_prompt_too_long():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": "This is a test " * 512,
+    })
+    assert res.status_code != 200
+    assert "too large" in res.body["error"]["message"]
+
+
+def test_same_prompt_give_same_result():
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 5
+    for i in range(1, len(res.body['data'])):
+        v0 = res.body['data'][0]['embedding']
+        vi = res.body['data'][i]['embedding']
+        for x, y in zip(v0, vi):
+            assert abs(x - y) < 1e-6
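
The normalization assertion in test_embedding_single encodes the fact that an L2-normalized vector has squared components summing to 1. A standalone restatement of that check (helper name is ours, not the repo's):

def is_unit_norm(vec: list[float], eps: float = 1e-6) -> bool:
    # L2-normalized means the sum of squared components equals 1,
    # i.e. the vector has Euclidean length 1.
    return abs(sum(x * x for x in vec) - 1.0) < eps

assert is_unit_norm([0.6, 0.8])      # 0.36 + 0.64 == 1.0
assert not is_unit_norm([1.0, 1.0])  # squared sum is 2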
