
Commit 30ea359

Olivier Chafik committed

update to minja's new api

1 parent 11c1f0c commit 30ea359

File tree

4 files changed: +170 -12 lines changed

common/chat-template.hpp
Lines changed: 22 additions & 0 deletions

@@ -270,6 +270,28 @@ class chat_template {
     const std::string & eos_token() const { return eos_token_; }
     const chat_template_caps & original_caps() const { return caps_; }
 
+    // Deprecated, please use the form with chat_template_inputs and chat_template_options
+    std::string apply(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
+        bool apply_polyfills = true)
+    {
+        fprintf(stderr, "[%s] Deprecated!\n", __func__);
+        chat_template_inputs inputs;
+        inputs.messages = messages;
+        inputs.tools = tools;
+        inputs.add_generation_prompt = add_generation_prompt;
+        inputs.extra_context = extra_context;
+        inputs.now = std::chrono::system_clock::now();
+
+        chat_template_options opts;
+        opts.apply_polyfills = apply_polyfills;
+
+        return apply(inputs, opts);
+    }
+
     std::string apply(
         const chat_template_inputs & inputs,
         const chat_template_options & opts = chat_template_options()) const
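For callers migrating off the deprecated overload, here is a minimal sketch of the new call shape, assuming the minja namespace used elsewhere in this commit; the render_prompt wrapper is hypothetical, not part of the diff:

// Hypothetical wrapper, for illustration only: moves a call from the
// deprecated apply(messages, tools, ...) overload to the inputs/options structs.
#include <chrono>
#include <string>
#include <nlohmann/json.hpp>
#include "chat-template.hpp" // assumed include path

static std::string render_prompt(const minja::chat_template & tmpl,
                                 const nlohmann::ordered_json & messages,
                                 const nlohmann::ordered_json & tools) {
    minja::chat_template_inputs inputs;
    inputs.messages = messages;
    inputs.tools = tools;
    inputs.add_generation_prompt = true;
    inputs.now = std::chrono::system_clock::now(); // some templates render the current date

    minja::chat_template_options opts; // defaults: polyfills stay enabled
    return tmpl.apply(inputs, opts);
}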

common/chat.cpp
Lines changed: 33 additions & 11 deletions

@@ -175,6 +175,28 @@ static void foreach_function(const json & tools, const std::function<void(const
     }
 }
 
+static std::string apply(
+    const common_chat_template & tmpl,
+    const nlohmann::ordered_json & messages,
+    const nlohmann::ordered_json & tools,
+    bool add_generation_prompt,
+    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+{
+    minja::chat_template_inputs tmpl_inputs;
+    tmpl_inputs.messages = messages;
+    tmpl_inputs.tools = tools;
+    tmpl_inputs.add_generation_prompt = add_generation_prompt;
+    tmpl_inputs.extra_context = extra_context;
+    // TODO: add flag to control date/time, if only for testing purposes.
+    // tmpl_inputs.now = std::chrono::system_clock::now();
+
+    minja::chat_template_options tmpl_opts;
+    tmpl_opts.use_bos_token = false;
+    tmpl_opts.use_eos_token = false;
+
+    return tmpl.apply(tmpl_inputs, tmpl_opts);
+}
+
 static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
 
@@ -256,7 +278,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         inputs.messages,
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
 
-    data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
@@ -322,7 +344,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
         builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
     }, grammar_options);
     data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
    data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
@@ -372,7 +394,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
         "<|END_THINKING|>",
         "<|END_ACTION|>",
     };
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
     return data;
 }
@@ -489,7 +511,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
@@ -568,7 +590,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
             "<|tool▁call▁end|>",
         };
     }, grammar_options);
-    auto prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
 
     // Hacks to fix the official (broken) prompt.
     // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -614,10 +636,10 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     fprintf(stderr, "%s\n", __func__);
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
         {"datetime", "Jan 29 2025 13:00:00 GMT"},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
-    }, /* adjust_inputs= */ false);
+    });
     if (!inputs.tools.is_null() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -661,7 +683,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (!inputs.tools.is_null() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != "required";
@@ -788,7 +810,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
     }, grammar_options);
 
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     // TODO: if (has_raw_python)
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
     return data;
@@ -843,7 +865,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         data.preserved_tokens = { "</tool_call>" };
     }, grammar_options);
 
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     return data;
 }
@@ -904,7 +926,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
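The new static apply helper keeps every common_chat_params_init_* call site a one-liner and pins use_bos_token/use_eos_token to false, presumably so the BOS/EOS special tokens are added once at tokenization time rather than baked into the prompt string. As a hedged sketch of how the extra_context argument is expected to reach the template, written as if it lived inside chat.cpp next to the helper (the prompt_without_user_message_tools function is hypothetical, not part of this diff):

// Hypothetical call site inside chat.cpp, for illustration only.
// Keys in extra_context are expected to surface as Jinja variables that the
// chat template can test, e.g. {%- if tools_in_user_message -%} ... {%- endif -%}.
static std::string prompt_without_user_message_tools(
        const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
    return apply(tmpl, inputs.messages,
                 inputs.tools.empty() ? json() : inputs.tools,
                 inputs.add_generation_prompt,
                 json {{"tools_in_user_message", false}});
}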

examples/run/run.cpp
Lines changed: 9 additions & 1 deletion

@@ -848,7 +848,15 @@ static int apply_chat_template(const common_chat_template & tmpl, LlamaData & ll
         });
     }
     try {
-        auto result = tmpl.apply(messages, /* tools= */ json(), append);
+        minja::chat_template_inputs tmpl_inputs;
+        tmpl_inputs.messages = messages;
+        tmpl_inputs.add_generation_prompt = append;
+
+        minja::chat_template_options tmpl_opts;
+        tmpl_opts.use_bos_token = false;
+        tmpl_opts.use_eos_token = false;
+
+        auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
         llama_data.fmtted.resize(result.size() + 1);
         memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
         return result.size();

examples/server/tests/unit/test_tool_call.py
Lines changed: 106 additions & 0 deletions

@@ -340,6 +340,112 @@ def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str
     assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("hf_repo,template_override", [
+    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
+    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
+
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
+    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
+
+    # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+])
+def test_calc_result(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+    global server
+    n_predict = 512
+    server.n_slots = 1
+    server.jinja = True
+    server.n_ctx = 8192
+    server.n_predict = n_predict
+    server.model_hf_repo = hf_repo
+    server.model_hf_file = None
+    if isinstance(template_override, tuple):
+        (template_hf_repo, template_variant) = template_override
+        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+    elif isinstance(template_override, str):
+        server.chat_template = template_override
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predict,
+        "messages": [
+            {"role": "system", "content": "You are a chatbot that uses tools/functions. Don't overthink things."},
+            {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "calculate",
+                            "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
+                        }
+                    }
+                ]
+            },
+            {
+                "role": "tool",
+                "name": "calculate",
+                "content": "0.5"
+            }
+        ],
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "calculate",
+                    "description": "A calculator function that computes values of arithmetic expressions in the Python syntax",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "expression": {
+                                "type": "string",
+                                "description": "An arithmetic expression to compute the value of (Python syntax, assuming all floats)"
+                            }
+                        },
+                        "required": ["expression"]
+                    }
+                }
+            }
+        ]
+    }, timeout=TIMEOUT_HTTP_REQUEST)
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls is None, f'Expected no further tool call in {choice["message"]}'
+    content = choice["message"].get("content")
+    assert content is not None and re.search(r'0\.5', content), f'Expected the tool result 0.5 to appear in the reply, got {content}'
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
     (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
