Skip to content

Commit c6bd7a7

Browse files
committed
add chat template test
1 parent 44f998a commit c6bd7a7

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

examples/server/tests/unit/test_chat_completion.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,23 @@ def test_chat_completion_with_openai_library():
100100
assert match_regex("(Suddenly)+", res.choices[0].message.content)
101101

102102

103+
def test_chat_template():
104+
global server
105+
server.chat_template = "llama3"
106+
server.debug = True # to get the "__verbose" object in the response
107+
server.start()
108+
res = server.make_request("POST", "/chat/completions", data={
109+
"max_tokens": 8,
110+
"messages": [
111+
{"role": "system", "content": "Book"},
112+
{"role": "user", "content": "What is the best book"},
113+
]
114+
})
115+
assert res.status_code == 200
116+
assert "__verbose" in res.body
117+
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
118+
119+
103120
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
104121
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
105122
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),

examples/server/tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class ServerProcess:
7474
draft_min: int | None = None
7575
draft_max: int | None = None
7676
no_webui: bool | None = None
77+
chat_template: str | None = None
7778

7879
# session variables
7980
process: subprocess.Popen | None = None
@@ -164,6 +165,8 @@ def start(self, timeout_seconds: int = 10) -> None:
164165
server_args.extend(["--draft-min", self.draft_min])
165166
if self.no_webui:
166167
server_args.append("--no-webui")
168+
if self.chat_template:
169+
server_args.extend(["--chat-template", self.chat_template])
167170

168171
args = [str(arg) for arg in [server_path, *server_args]]
169172
print(f"bench: starting server with: {' '.join(args)}")

0 commit comments

Comments
 (0)