
Commit fb7bf27

merge from main
1 parent 54f128a

1 file changed: +21 -28

common/chat.cpp

Lines changed: 21 additions & 28 deletions
@@ -721,23 +721,16 @@ static void foreach_function(const json & tools, const std::function<void(const

 static std::string apply(
     const common_chat_template & tmpl,
-    const struct templates_params & inputs,
-    const std::optional<json> & messages_override = std::nullopt,
-    const std::optional<json> & tools_override = std::nullopt,
-    const std::optional<json> & additional_context = std::nullopt)
+    const nlohmann::ordered_json & messages,
+    const nlohmann::ordered_json & tools,
+    bool add_generation_prompt,
+    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
 {
     minja::chat_template_inputs tmpl_inputs;
-    tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
-    if (tools_override) {
-        tmpl_inputs.tools = *tools_override;
-    } else {
-        tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
-    }
-    tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
-    tmpl_inputs.extra_context = inputs.extra_context;
-    if (additional_context) {
-        tmpl_inputs.extra_context.merge_patch(*additional_context);
-    }
+    tmpl_inputs.messages = messages;
+    tmpl_inputs.tools = tools;
+    tmpl_inputs.add_generation_prompt = add_generation_prompt;
+    tmpl_inputs.extra_context = extra_context;
     // TODO: add flag to control date/time, if only for testing purposes.
     // tmpl_inputs.now = std::chrono::system_clock::now();
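
Note: `apply` previously took the whole `templates_params` struct plus optional overrides; it now takes the messages, tools, generation-prompt flag, and extra context as explicit arguments, so each call site spells out exactly what the template sees. A minimal before/after sketch of a call site, using the same `inputs` object the callers in this file already have:

    // before: unspecified arguments fell back to the fields of `inputs`
    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);

    // after: every input is passed explicitly
    data.prompt = apply(tmpl, tweaked_messages,
                        inputs.tools.empty() ? json() : inputs.tools,
                        inputs.add_generation_prompt,
                        inputs.extra_context);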

@@ -836,7 +829,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         inputs.messages,
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");

-    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
+    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context);
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
@@ -912,7 +905,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.preserved_tokens = {
         "[TOOL_CALLS]",
     };
-    data.prompt = apply(tmpl, inputs);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
@@ -942,7 +935,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
             adjusted_messages.push_back(msg);
         }
     }
-    data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
+    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
     if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
         if (!inputs.enable_thinking) {
@@ -1130,7 +1123,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
     } else {
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
-    data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json {
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
         {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
@@ -1195,7 +1188,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w

 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    auto prompt = apply(tmpl, inputs);
+    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);

     // Hacks to fix the official (broken) prompt.
     // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -1290,7 +1283,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ json(), json {
+    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
         {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
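
Note: FireFunction is the one call site that deliberately nulls out the tools; they are instead serialized into the `functions` entry of the extra context, as the hunk above shows. With nlohmann/json, `nullptr` constructs a JSON null, so `/* tools= */ nullptr` matches the old `/* tools_override= */ json()`. A quick sanity check of that equivalence (standalone sketch, not part of this commit):

    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    int main() {
        json a = nullptr; // what the new call site passes for `tools`
        json b = json();  // what the old `tools_override = json()` produced
        return (a.is_null() && b.is_null()) ? 0 : 1; // both are JSON null
    }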
@@ -1346,7 +1339,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1473,7 +1466,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }

-    data.prompt = apply(tmpl, inputs);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     // TODO: if (has_raw_python)
     return data;
 }
@@ -1506,15 +1499,15 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;

-    json extra_context = json {
+    json additional_context = {
         {"enable_thinking", inputs.enable_thinking},
     };
-    extra_context.update(inputs.extra_context);
+    additional_context.update(inputs.extra_context);

-    data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, extra_context);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context);
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     if (string_ends_with(data.prompt, "<think>\n")) {
-        if (!extra_context["enable_thinking"]) {
+        if (!additional_context["enable_thinking"]) {
             data.prompt += "</think>";
         } else {
             data.thinking_forced_open = true;
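
Note: the local `extra_context` is renamed to `additional_context` so it no longer shares a name with the `extra_context` parameter of the new `apply` signature; its behavior is unchanged. `json::update` still lets caller-supplied keys in `inputs.extra_context` override the `enable_thinking` default, as this standalone sketch (not part of the commit) illustrates:

    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    int main() {
        json additional_context = { {"enable_thinking", true} };
        json caller_overrides   = { {"enable_thinking", false} };
        additional_context.update(caller_overrides); // caller keys overwrite defaults
        return additional_context["enable_thinking"] == false ? 0 : 1; // exits 0
    }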
@@ -1700,7 +1693,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {

 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, inputs.extra_context);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
