Skip to content

Commit 496d940

Browse files
committed
Added tests for start string feature
1 parent 0524a08 commit 496d940

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

examples/server/tests/unit/test_chat_completion.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,85 @@ def test_logprobs_stream():
309309
assert token.top_logprobs is not None
310310
assert len(token.top_logprobs) > 0
311311
assert aggregated_text == output_text
312+
313+
314+
def test_startstring_serverconfig():
315+
global server
316+
server.jinja = False
317+
server.start_string=" 9 "
318+
server.start()
319+
res = server.make_request("POST", "/chat/completions", data={
320+
"max_tokens": 32,
321+
"messages": [
322+
{"role": "user", "content": "List the numbers from 1 to 100"},
323+
],
324+
"grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
325+
})
326+
assert res.status_code == 200, res.body
327+
choice = res.body["choices"][0]
328+
content = choice["message"]["content"]
329+
print(content)
330+
assert content.startswith("10 ")
331+
332+
def test_startstring_clientconfig():
333+
global server
334+
server.jinja = False
335+
server.start()
336+
res = server.make_request("POST", "/chat/completions", data={
337+
"max_tokens": 32,
338+
"messages": [
339+
{"role": "user", "content": "List the numbers from 1 to 100"},
340+
],
341+
"grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
342+
"start_strings": ["10"]
343+
})
344+
assert res.status_code == 200, res.body
345+
choice = res.body["choices"][0]
346+
content = choice["message"]["content"]
347+
assert content.startswith(" 11")
348+
349+
350+
def test_startstring_clientconfig_stream():
351+
global server
352+
server.jinja = False
353+
server.start()
354+
max_tokens=64
355+
system_prompt=""
356+
user_prompt=""
357+
res = server.make_stream_request("POST", "/chat/completions", data={
358+
"max_tokens": max_tokens,
359+
"messages": [
360+
{"role": "system", "content": system_prompt},
361+
{"role": "user", "content": user_prompt},
362+
],
363+
"grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\" .+",
364+
"start_strings": ["10"],
365+
"stream": True,
366+
})
367+
368+
content = ""
369+
last_cmpl_id = None
370+
for data in res:
371+
choice = data["choices"][0]
372+
if choice["finish_reason"] not in ["stop", "length"]:
373+
delta = choice["delta"]["content"]
374+
content += delta
375+
assert content.startswith(" 11")
376+
377+
378+
def test_startstring_clientconfig_multiple():
379+
global server
380+
server.jinja = False
381+
server.start()
382+
res = server.make_request("POST", "/chat/completions", data={
383+
"max_tokens": 32,
384+
"messages": [
385+
{"role": "user", "content": "List the numbers from 1 to 100"},
386+
],
387+
"grammar": "root ::= \"1 2 3 4 5 6 7 8 9 10 11 12\"",
388+
"start_strings": ["10","9"]
389+
})
390+
assert res.status_code == 200, res.body
391+
choice = res.body["choices"][0]
392+
content = choice["message"]["content"]
393+
assert content.startswith(" 10")

examples/server/tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class ServerProcess:
8888
chat_template: str | None = None
8989
chat_template_file: str | None = None
9090
server_path: str | None = None
91+
start_string: str | None = None
9192

9293
# session variables
9394
process: subprocess.Popen | None = None
@@ -194,6 +195,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
194195
server_args.extend(["--chat-template", self.chat_template])
195196
if self.chat_template_file:
196197
server_args.extend(["--chat-template-file", self.chat_template_file])
198+
if self.start_string:
199+
server_args.extend(["--start_string", self.start_string])
197200

198201
args = [str(arg) for arg in [server_path, *server_args]]
199202
print(f"tests: starting server with: {' '.join(args)}")

0 commit comments

Comments
 (0)