@@ -309,3 +309,85 @@ def test_logprobs_stream():
309309 assert token .top_logprobs is not None
310310 assert len (token .top_logprobs ) > 0
311311 assert aggregated_text == output_text
312+
313+
314+ def test_startstring_serverconfig ():
315+ global server
316+ server .jinja = False
317+ server .start_string = " 9 "
318+ server .start ()
319+ res = server .make_request ("POST" , "/chat/completions" , data = {
320+ "max_tokens" : 32 ,
321+ "messages" : [
322+ {"role" : "user" , "content" : "List the numbers from 1 to 100" },
323+ ],
324+ "grammar" : "root ::= \" 1 2 3 4 5 6 7 8 9 10 11 12\" " ,
325+ })
326+ assert res .status_code == 200 , res .body
327+ choice = res .body ["choices" ][0 ]
328+ content = choice ["message" ]["content" ]
329+ print (content )
330+ assert content .startswith ("10 " )
331+
332+ def test_startstring_clientconfig ():
333+ global server
334+ server .jinja = False
335+ server .start ()
336+ res = server .make_request ("POST" , "/chat/completions" , data = {
337+ "max_tokens" : 32 ,
338+ "messages" : [
339+ {"role" : "user" , "content" : "List the numbers from 1 to 100" },
340+ ],
341+ "grammar" : "root ::= \" 1 2 3 4 5 6 7 8 9 10 11 12\" " ,
342+ "start_strings" : ["10" ]
343+ })
344+ assert res .status_code == 200 , res .body
345+ choice = res .body ["choices" ][0 ]
346+ content = choice ["message" ]["content" ]
347+ assert content .startswith (" 11" )
348+
349+
350+ def test_startstring_clientconfig_stream ():
351+ global server
352+ server .jinja = False
353+ server .start ()
354+ max_tokens = 64
355+ system_prompt = ""
356+ user_prompt = ""
357+ res = server .make_stream_request ("POST" , "/chat/completions" , data = {
358+ "max_tokens" : max_tokens ,
359+ "messages" : [
360+ {"role" : "system" , "content" : system_prompt },
361+ {"role" : "user" , "content" : user_prompt },
362+ ],
363+ "grammar" : "root ::= \" 1 2 3 4 5 6 7 8 9 10 11 12\" .+" ,
364+ "start_strings" : ["10" ],
365+ "stream" : True ,
366+ })
367+
368+ content = ""
369+ last_cmpl_id = None
370+ for data in res :
371+ choice = data ["choices" ][0 ]
372+ if choice ["finish_reason" ] not in ["stop" , "length" ]:
373+ delta = choice ["delta" ]["content" ]
374+ content += delta
375+ assert content .startswith (" 11" )
376+
377+
378+ def test_startstring_clientconfig_multiple ():
379+ global server
380+ server .jinja = False
381+ server .start ()
382+ res = server .make_request ("POST" , "/chat/completions" , data = {
383+ "max_tokens" : 32 ,
384+ "messages" : [
385+ {"role" : "user" , "content" : "List the numbers from 1 to 100" },
386+ ],
387+ "grammar" : "root ::= \" 1 2 3 4 5 6 7 8 9 10 11 12\" " ,
388+ "start_strings" : ["10" ,"9" ]
389+ })
390+ assert res .status_code == 200 , res .body
391+ choice = res .body ["choices" ][0 ]
392+ content = choice ["message" ]["content" ]
393+ assert content .startswith (" 10" )
0 commit comments