Skip to content

Commit 9162380

Browse files
author
ochafik
committed
Merge branch 'fix-gen-prompt' into enable-thinking
2 parents 9cdeebe + 8a25f79 commit 9162380

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

tools/server/tests/unit/test_template.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,28 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
8181

8282
today_str = datetime.date.today().strftime(format)
8383
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
84+
85+
86+
@pytest.mark.parametrize("add_generation_prompt", [False, True])
87+
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
88+
("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
89+
])
90+
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
91+
global server
92+
server.jinja = True
93+
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
94+
server.start(timeout_seconds=TIMEOUT_SERVER_START)
95+
96+
res = server.make_request("POST", "/apply-template", data={
97+
"messages": [
98+
{"role": "user", "content": "What is today?"},
99+
],
100+
"add_generation_prompt": add_generation_prompt,
101+
})
102+
assert res.status_code == 200
103+
prompt = res.body["prompt"]
104+
105+
if add_generation_prompt:
106+
assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
107+
else:
108+
assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"

tools/server/utils.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,7 @@ static json oaicompat_chat_params_parse(
732732
inputs.grammar = grammar;
733733
inputs.use_jinja = opt.use_jinja;
734734
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
735+
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
735736
inputs.reasoning_format = opt.reasoning_format;
736737
inputs.enable_thinking = opt.enable_thinking;
737738
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {

0 commit comments

Comments
 (0)