Skip to content

Commit 472e128

Browse files
committed
added all sequential tests
1 parent eb02373 commit 472e128

File tree

8 files changed

+332
-6
lines changed

8 files changed

+332
-6
lines changed

examples/server/tests/unit/test_completion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,5 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
120120
if last_res is not None:
121121
assert res.body["content"] == last_res.body["content"]
122122
last_res = res
123+
124+
# TODO: add completion with tokens as input, mixed token+string input
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import pytest
from utils import *

server = ServerPreset.tinyllama2()


# Prompt used by the context-shift tests; per the comments in
# test_ctx_shift_enabled below it tokenizes to 301 tokens, which overflows
# the 128-token per-slot context configured in create_server.
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()
13+
14+
@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Recreate a fresh tinyllama2 preset before the tests in this module run."""
    global server
    server = ServerPreset.tinyllama2()
    # 256-token context shared across 2 slots -> 128 tokens per slot
    server.n_ctx = 256
    server.n_slots = 2
20+
21+
22+
def test_ctx_shift_enabled():
    """Context shifting lets generation continue past a full slot context.

    The prompt is 301 tokens and the slot context is 256/2 = 128 tokens, so
    the prompt is truncated to keep the last 109 tokens; 64 tokens are then
    generated thanks to shifting the context when it gets full.
    """
    global server
    server.start()
    response = server.make_request("POST", "/completion", data={
        "n_predict": 64,
        "prompt": LONG_TEXT,
    })
    assert response.status_code == 200
    timings = response.body["timings"]
    assert timings["prompt_n"] == 109
    assert timings["predicted_n"] == 64
    assert response.body["truncated"] is True
37+
38+
39+
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
    (64, 64, False),
    (-1, 120, True),
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    """With context shift disabled, a short prompt generates until n_predict
    or until the slot context is exhausted (the truncated case)."""
    global server
    server.disable_ctx_shift = True
    server.n_predict = -1
    server.start()
    response = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    })
    assert response.status_code == 200
    body = response.body
    assert body["timings"]["predicted_n"] == n_token_output
    assert body["truncated"] == truncated
55+
56+
57+
def test_ctx_shift_disabled_long_prompt():
    """A prompt larger than the slot context must be rejected outright when
    context shifting is disabled."""
    global server
    server.disable_ctx_shift = True
    server.start()
    response = server.make_request("POST", "/completion", data={
        "n_predict": 64,
        "prompt": LONG_TEXT,
    })
    assert response.status_code != 200
    body = response.body
    assert "error" in body
    assert "exceeds the available context size" in body["error"]["message"]

examples/server/tests/unit/test_embedding.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
server = ServerPreset.bert_bge_small()
66

7+
EPSILON = 1e-3
78

89
@pytest.fixture(scope="module", autouse=True)
910
def create_server():
@@ -23,7 +24,7 @@ def test_embedding_single():
2324
assert len(res.body['data'][0]['embedding']) > 1
2425

2526
# make sure embedding vector is normalized
26-
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < 1e-5
27+
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
2728

2829

2930
def test_embedding_multiple():
@@ -95,4 +96,4 @@ def test_same_prompt_give_same_result():
9596
v0 = res.body['data'][0]['embedding']
9697
vi = res.body['data'][i]['embedding']
9798
for x, y in zip(v0, vi):
98-
assert abs(x - y) < 1e-5
99+
assert abs(x - y) < EPSILON
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import pytest
from utils import *

server = ServerPreset.tinyllama_infill()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Recreate a fresh infill-capable preset before each test in this module."""
    global server
    server = ServerPreset.tinyllama_infill()
10+
11+
def test_infill_without_input_extra():
    """Basic /infill request with only a prefix and suffix."""
    global server
    server.start()
    payload = {
        "prompt": "Complete this",
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
        "input_suffix": "}\n",
    }
    response = server.make_request("POST", "/infill", data=payload)
    assert response.status_code == 200
    # tiny story model: the completion drifts into story text
    assert match_regex("(One|day|she|saw|big|scary|bird)+", response.body["content"])
21+
22+
def test_infill_with_input_extra():
    """/infill request that also supplies extra context files via input_extra."""
    global server
    server.start()
    extra_files = [{
        "filename": "llama.h",
        "text": "LLAMA_API int32_t llama_n_threads();\n",
    }]
    payload = {
        "prompt": "Complete this",
        "input_extra": extra_files,
        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
        "input_suffix": "}\n",
    }
    response = server.make_request("POST", "/infill", data=payload)
    assert response.status_code == 200
    assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", response.body["content"])
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import pytest
import os
from utils import *

server = ServerPreset.stories15m_moe()

# LoRA adapter for the stories15M MoE model; per the parametrized cases below,
# applying it shifts output from bedtime-story text to Shakespearean text.
LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
8+
9+
@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Recreate the stories15M-MoE preset and attach the Shakespeare LoRA.

    The adapter file is downloaded once and cached on disk between runs.
    """
    global server
    server = ServerPreset.stories15m_moe()
    # download lora file if needed
    file_name = LORA_FILE_URL.split('/').pop()
    lora_file = f'../../../{file_name}'
    if not os.path.exists(lora_file):
        print(f"Downloading {LORA_FILE_URL} to {lora_file}")
        res = requests.get(LORA_FILE_URL)
        # fail loudly on a bad response instead of silently writing an
        # HTML error page into the cached .gguf file
        res.raise_for_status()
        with open(lora_file, 'wb') as f:
            f.write(res.content)
        print("Done downloading lora file")
    server.lora_files = [lora_file]
22+
23+
24+
@pytest.mark.parametrize("scale,re_content", [
    # without applying lora, the model should behave like a bedtime story generator
    (0.0, "(little|girl|three|years|old)+"),
    # with lora, the model should behave like a Shakespearean text generator
    (1.0, "(eye|love|glass|sun)+"),
])
def test_lora(scale: float, re_content: str):
    """Set the adapter scale via /lora-adapters, then sample a completion."""
    global server
    server.start()
    scale_response = server.make_request("POST", "/lora-adapters", data=[
        {"id": 0, "scale": scale}
    ])
    assert scale_response.status_code == 200
    completion = server.make_request("POST", "/completion", data={
        "prompt": "Look in thy glass",
    })
    assert completion.status_code == 200
    assert match_regex(re_content, completion.body["content"])
42+
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
from utils import *

server = ServerPreset.jina_reranker_tiny()


@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Recreate a fresh reranker preset before each test in this module."""
    global server
    server = ServerPreset.jina_reranker_tiny()
11+
12+
13+
def test_rerank():
    """Rerank four documents and verify the best/worst match indices."""
    global server
    server.start()
    response = server.make_request("POST", "/rerank", data={
        "query": "Machine learning is",
        "documents": [
            "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
            "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
            "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
            "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
        ]
    })
    assert response.status_code == 200
    results = response.body["results"]
    assert len(results) == 4

    # max/min both return the first element on ties, matching the original
    # strict >/< scan initialized from results[0]
    most_relevant = max(results, key=lambda doc: doc["relevance_score"])
    least_relevant = min(results, key=lambda doc: doc["relevance_score"])

    assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
    assert most_relevant["index"] == 2   # the literal "machine learning" definition
    assert least_relevant["index"] == 3  # unrelated French text about Paris
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import pytest
from utils import *

server = ServerPreset.tinyllama2()

@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Recreate a fresh preset with slot save/restore enabled (files in ./tmp)."""
    global server
    server = ServerPreset.tinyllama2()
    # directory where /slots/<id>?action=save writes its state files
    server.slot_save_path = "./tmp"
11+
12+
13+
def test_slot_save_restore():
    """End-to-end KV-cache save/restore: save slot 1's state to a file,
    restore it into slot 0, and check prompt caching works in both slots.

    The exact prompt_n / n_saved values (21, 84, 6, 1) are tied to this
    model's tokenizer and must be updated if the preset model changes.
    """
    global server
    server.start()

    # First prompt in slot 1 should be fully processed
    res = server.make_request("POST", "/completion", data={
        "prompt": "What is the capital of France?",
        "id_slot": 1,
        "cache_prompt": True,
    })
    assert res.status_code == 200
    assert match_regex("(Lily|cake)+", res.body["content"])
    assert res.body["timings"]["prompt_n"] == 21  # all tokens are processed

    # Save state of slot 1
    res = server.make_request("POST", "/slots/1?action=save", data={
        "filename": "slot1.bin",
    })
    assert res.status_code == 200
    assert res.body["n_saved"] == 84

    # Since we have cache, this should only process the last tokens
    res = server.make_request("POST", "/completion", data={
        "prompt": "What is the capital of Germany?",
        "id_slot": 1,
        "cache_prompt": True,
    })
    assert res.status_code == 200
    assert match_regex("(Jack|said)+", res.body["content"])
    assert res.body["timings"]["prompt_n"] == 6  # only different part is processed

    # Loading the saved cache into slot 0
    res = server.make_request("POST", "/slots/0?action=restore", data={
        "filename": "slot1.bin",
    })
    assert res.status_code == 200
    assert res.body["n_restored"] == 84

    # Since we have cache, slot 0 should only process the last tokens
    res = server.make_request("POST", "/completion", data={
        "prompt": "What is the capital of Germany?",
        "id_slot": 0,
        "cache_prompt": True,
    })
    assert res.status_code == 200
    assert match_regex("(Jack|said)+", res.body["content"])
    assert res.body["timings"]["prompt_n"] == 6  # only different part is processed

    # For verification that slot 1 was not corrupted during slot 0 load, same thing should work
    res = server.make_request("POST", "/completion", data={
        "prompt": "What is the capital of Germany?",
        "id_slot": 1,
        "cache_prompt": True,
    })
    assert res.status_code == 200
    assert match_regex("(Jack|said)+", res.body["content"])
    assert res.body["timings"]["prompt_n"] == 1
70+
71+
72+
def test_slot_erase():
    """Erasing a slot must drop its prompt cache, forcing full reprocessing."""
    global server
    server.start()

    payload = {
        "prompt": "What is the capital of France?",
        "id_slot": 1,
        "cache_prompt": True,
    }

    # prime slot 1: every prompt token is processed
    res = server.make_request("POST", "/completion", data=payload)
    assert res.status_code == 200
    assert match_regex("(Lily|cake)+", res.body["content"])
    assert res.body["timings"]["prompt_n"] == 21  # all tokens are processed

    # erase slot 1
    res = server.make_request("POST", "/slots/1?action=erase")
    assert res.status_code == 200

    # re-run the same prompt, it should process all tokens again
    res = server.make_request("POST", "/completion", data=payload)
    assert res.status_code == 200
    assert match_regex("(Lily|cake)+", res.body["content"])
    assert res.body["timings"]["prompt_n"] == 21  # all tokens are processed

examples/server/tests/utils.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ContextManager,
1818
Iterable,
1919
Iterator,
20+
List,
2021
Literal,
2122
Sequence,
2223
Set,
@@ -65,7 +66,7 @@ class ServerProcess:
6566
draft: int | None = None
6667
api_key: str | None = None
6768
response_format: str | None = None
68-
lora_file: str | None = None
69+
lora_files: List[str] | None = None
6970
disable_ctx_shift: int | None = False
7071

7172
# session variables
@@ -134,8 +135,9 @@ def start(self, timeout_seconds: int = 10) -> None:
134135
server_args.extend(["--grp-attn-w", self.n_ga_w])
135136
if self.debug:
136137
server_args.append("--verbose")
137-
if self.lora_file:
138-
server_args.extend(["--lora", self.lora_file])
138+
if self.lora_files:
139+
for lora_file in self.lora_files:
140+
server_args.extend(["--lora", lora_file])
139141
if self.disable_ctx_shift:
140142
server_args.extend(["--no-context-shift"])
141143
if self.api_key:
@@ -202,7 +204,7 @@ def make_request(
202204
self,
203205
method: str,
204206
path: str,
205-
data: dict | None = None,
207+
data: dict | Any | None = None,
206208
headers: dict | None = None,
207209
) -> ServerResponse:
208210
url = f"http://{self.server_host}:{self.server_port}{path}"
@@ -277,6 +279,48 @@ def bert_bge_small() -> ServerProcess:
277279
server.server_embeddings = True
278280
return server
279281

282+
@staticmethod
def tinyllama_infill() -> ServerProcess:
    """Preset: stories260K infill model for the /infill endpoint tests."""
    server = ServerProcess()
    server.model_hf_repo = "ggml-org/models"
    server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
    server.model_alias = "tinyllama-infill"
    server.n_ctx = 2048
    server.n_batch = 1024
    server.n_slots = 1
    server.n_predict = 64
    # deterministic sampling so tests can match expected output
    server.temperature = 0.0
    server.seed = 42
    return server
295+
296+
@staticmethod
def stories15m_moe() -> ServerProcess:
    """Preset: stories15M MoE model, used by the LoRA adapter tests."""
    server = ServerProcess()
    server.model_hf_repo = "ggml-org/stories15M_MOE"
    server.model_hf_file = "stories15M_MOE-F16.gguf"
    server.model_alias = "stories15m-moe"
    server.n_ctx = 2048
    server.n_batch = 1024
    server.n_slots = 1
    server.n_predict = 64
    # deterministic sampling so tests can match expected output
    server.temperature = 0.0
    server.seed = 42
    return server
309+
310+
@staticmethod
def jina_reranker_tiny() -> ServerProcess:
    """Preset: tiny Jina reranker model with reranking endpoint enabled."""
    server = ServerProcess()
    server.model_hf_repo = "ggml-org/models"
    server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
    server.model_alias = "jina-reranker"
    # NOTE(review): local path set alongside model_hf_file — presumably the
    # HF download is stored/looked up here; confirm against ServerProcess.start
    server.model_file = "./tmp/jina-reranker-v1-tiny-en.gguf"
    server.n_ctx = 512
    server.n_batch = 512
    server.n_slots = 1
    server.seed = 42
    # expose the /rerank endpoint
    server.server_reranking = True
    return server
323+
280324

281325
def multiple_post_requests(
282326
server: ServerProcess, path: str, data: Sequence[dict], headers: dict | None = None

0 commit comments

Comments
 (0)