Skip to content

Commit 34e4e22

Browse files
author
ochafik
committed
fix & test grammar & json_schema w/ & w/o --jinja
1 parent 76f5d27 commit 34e4e22

File tree

4 files changed

+45
-7
lines changed

4 files changed

+45
-7
lines changed

common/chat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1431,7 +1431,7 @@ static common_chat_params common_chat_templates_apply_jinja(
14311431

14321432
// Use generic handler when mixing tools + JSON schema.
14331433
// TODO: support that mix in handlers below.
1434-
if ((!params.tools.is_array() && params.json_schema.is_object())) {
1434+
if ((params.tools.is_array() && params.json_schema.is_object())) {
14351435
return common_chat_params_init_generic(tmpl, params);
14361436
}
14371437

examples/server/server.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,6 @@ struct server_task {
329329
}
330330

331331
// process "json_schema" and "grammar"
332-
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
333-
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
334-
}
335332
if (data.contains("json_schema") && !data.contains("grammar")) {
336333
try {
337334
auto schema = json_value(data, "json_schema", json::object());

examples/server/tests/unit/test_chat_completion.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
169169
assert "error" in res.body
170170

171171

172+
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
173+
(False, {"const": "42"}, 6, "\"42\""),
174+
(True, {"const": "42"}, 6, "\"42\""),
175+
])
176+
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
177+
global server
178+
server.jinja = jinja
179+
server.start()
180+
res = server.make_request("POST", "/chat/completions", data={
181+
"max_tokens": n_predicted,
182+
"messages": [
183+
{"role": "system", "content": "You are a coding assistant."},
184+
{"role": "user", "content": "Write an example"},
185+
],
186+
"json_schema": json_schema,
187+
})
188+
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
189+
choice = res.body["choices"][0]
190+
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
191+
192+
193+
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
194+
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
195+
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
196+
])
197+
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
198+
global server
199+
server.jinja = jinja
200+
server.start()
201+
res = server.make_request("POST", "/chat/completions", data={
202+
"max_tokens": n_predicted,
203+
"messages": [
204+
{"role": "user", "content": "Does not matter what I say, does it?"},
205+
],
206+
"grammar": grammar,
207+
})
208+
assert res.status_code == 200, res.body
209+
choice = res.body["choices"][0]
210+
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
211+
212+
172213
@pytest.mark.parametrize("messages", [
173214
None,
174215
"string",

examples/server/utils.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,8 @@ static json oaicompat_completion_params_parse(
571571
llama_params["stop"] = json_value(body, "stop", json::array());
572572
}
573573

574-
auto json_schema = json_value(llama_params, "json_schema", json());
575-
auto grammar = json_value(llama_params, "grammar", std::string());
574+
auto json_schema = json_value(body, "json_schema", json());
575+
auto grammar = json_value(body, "grammar", std::string());
576576
if (!json_schema.is_null() && !grammar.empty()) {
577577
throw std::runtime_error("Cannot use both json_schema and grammar");
578578
}
@@ -601,7 +601,7 @@ static json oaicompat_completion_params_parse(
601601
inputs.use_jinja = use_jinja;
602602
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
603603
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
604-
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && llama_params.contains("grammar")) {
604+
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
605605
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
606606
}
607607

0 commit comments

Comments
 (0)