Skip to content

tool-call: Qwen3 Coder chat format support #15162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 129 additions & 2 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_QWEN3: return "Qwen3";

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this is the Qwen3 Coder-specific format, not one shared by all Qwen3 models?

case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default:
Expand Down Expand Up @@ -1059,6 +1060,124 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
}
}


// Initializes the prompt and the (lazy) tool-call grammar for the
// Qwen3-Coder XML-style format:
//   (content)?(<tool_call>\n<function=NAME>\n<parameter=KEY>\nVALUE\n</parameter>...\n</function>\n</tool_call>)*
static common_chat_params common_chat_params_init_qwen3(const common_chat_template & tmpl, const struct templates_params & inputs) {
    common_chat_params data;

    data.prompt = apply(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_QWEN3;
    // NOTE(review): inputs.enable_thinking is currently not forwarded to the
    // template and thinking_forced_open is never set here — if the rendered
    // prompt can end with "<think>\n", that case still needs to be handled
    // (cf. the formats that either close or force-open the think block).

    if (!inputs.tools.is_null()) {
        // Only activate the grammar once a tool call starts, unless a tool
        // call is mandatory (tool_choice == required).
        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            std::vector<std::string> tool_rules;
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                std::string name = function.at("name");
                auto parameters = function.at("parameters");
                builder.resolve_refs(parameters);

                // One rule per tool; each parameter becomes a
                // <parameter=NAME> block, optional ones wrapped in ( ... )?.
                std::vector<std::string> fragments;
                fragments.push_back(" space \"<tool_call>\\n<function=" + name + ">\\n\"");

                const auto & properties = parameters.at("properties");
                std::vector<std::string> required;
                if (parameters.contains("required")) {
                    required = parameters.at("required");
                }

                for (const auto & [param_name, param_value] : properties.items()) {
                    const auto optional = std::find(required.begin(), required.end(), param_name) == required.end();
                    if (optional) {
                        fragments.push_back("(");
                    }
                    fragments.push_back("\"<parameter=" + param_name + ">\\n\" ");
                    fragments.push_back(builder.add_schema(name + "-parameter-" + param_name, param_value));
                    fragments.push_back("\"\\n</parameter>\\n\"");
                    if (optional) {
                        fragments.push_back(")? ");
                    }
                }
                fragments.push_back("\"</function>\\n</tool_call>\"");

                tool_rules.push_back(builder.add_rule(
                    name + "-function-tag",
                    string_join(fragments, " ")));
            });
            auto tool_call = tool_rules.size() == 1 ? tool_rules[0] : builder.add_rule("tool_call", string_join(tool_rules, " | "));
            builder.add_rule("root",
                std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
                (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
            // Trigger the lazy grammar on the first tool-call opening tag.
            data.grammar_triggers.push_back({
                COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
                "<tool_call>"
            });
            // Keep these markers intact so the parser sees them verbatim.
            data.preserved_tokens = {
                "<think>",
                "</think>",
                "<tool_call>",
                "</tool_call>",
            };
        });
    }

    return data;
}

// Parses a Qwen3-Coder assistant message: optional <think>...</think>
// reasoning, optional plain content, then zero or more
// <tool_call>\n<function=NAME>\n<parameter=KEY>\nVALUE\n</parameter>...\n</function>\n</tool_call> blocks.
static void common_chat_parse_qwen3(common_chat_msg_parser & builder) {
    builder.try_parse_reasoning("<think>", "</think>");
    if (!builder.syntax().parse_tool_calls) {
        builder.add_content(builder.consume_rest());
        return;
    }

    static const common_regex function_open("\\s*<tool_call>\n<function=([a-zA-Z0-9_]+)>\n");
    static const common_regex function_close("</function>\n</tool_call>");
    static const common_regex parameter_open("<parameter=([a-zA-Z0-9_]+)>\n");
    static const common_regex parameter_close("</parameter>\n");

    while (auto block_open_match = builder.try_find_regex(function_open)) {
        const auto function_name = builder.str(block_open_match->groups[1]);
        json arguments = json::object();
        // Collect all <parameter=...> blocks belonging to this call.
        // NOTE(review): values are consumed as JSON; the template's example
        // shows raw (unquoted, possibly multi-line) values — confirm that the
        // grammar above forces string parameters to be emitted as JSON.
        while (true) {
            builder.consume_spaces();
            if (auto param_open_match = builder.try_consume_regex(parameter_open)) {
                const auto parameter_name = builder.str(param_open_match->groups[1]);
                const auto parameter_value = builder.consume_json();
                arguments[parameter_name] = parameter_value.json;
                builder.consume_spaces();
                builder.consume_regex(parameter_close);
            } else {
                break;
            }
        }
        builder.consume_regex(function_close);
        builder.consume_spaces();
        // Serialize compactly: tool-call arguments are expected as canonical
        // compact JSON (matches the other chat formats and the PR's own test,
        // which asserts {"code":"..."} with no pretty-print whitespace).
        builder.add_tool_call(function_name, "", arguments.dump());
    }
    builder.add_content(builder.consume_rest());
}

static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
Expand Down Expand Up @@ -1930,9 +2049,14 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_granite(tmpl, params);
}

// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
if (src.find("<parameter=") != std::string::npos && params.json_schema.is_null()) {
// Qwen3-Coder
return common_chat_params_init_qwen3(tmpl, params);
} else {
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
return common_chat_params_init_hermes_2_pro(tmpl, params);
}
}

// GPT-OSS
Expand Down Expand Up @@ -2092,6 +2216,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3:
common_chat_parse_qwen3(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
Expand Down
1 change: 1 addition & 0 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_QWEN3,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,

Expand Down
131 changes: 131 additions & 0 deletions models/templates/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
{#- Qwen3-Coder chat template: renders tool schemas as XML-ish <function>/<parameter> blocks in the system prompt, and assistant tool calls as <tool_call><function=NAME><parameter=KEY>...; whitespace-control markers ({%- / -}}) are significant — do not reformat. -#}
{% macro render_item_list(item_list, tag_name='required') %}
{%- if item_list is defined and item_list is iterable and item_list | length > 0 %}
{%- if tag_name %}{{- '\n<' ~ tag_name ~ '>' -}}{% endif %}
{{- '[' }}
{%- for item in item_list -%}
{%- if loop.index > 1 %}{{- ", "}}{% endif -%}
{%- if item is string -%}
{{ "`" ~ item ~ "`" }}
{%- else -%}
{{ item }}
{%- endif -%}
{%- endfor -%}
{{- ']' }}
{%- if tag_name %}{{- '</' ~ tag_name ~ '>' -}}{% endif %}
{%- endif %}
{% endmacro %}

{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}

{%- if not tools is defined %}
{%- set tools = [] %}
{%- endif %}

{%- if system_message is defined %}
{{- "<|im_start|>system\n" + system_message }}
{%- else %}
{%- if tools is iterable and tools | length > 0 %}
{{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }}
{%- endif %}
{%- endif %}
{%- if tools is iterable and tools | length > 0 %}
{{- "\n\nYou have access to the following functions:\n\n" }}
{{- "<tools>" }}
{%- for tool in tools %}
{%- if tool.function is defined %}
{%- set tool = tool.function %}
{%- endif %}
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
{{- '\n<parameters>' }}
{%- for param_name, param_fields in tool.parameters.properties|items %}
{{- '\n<parameter>' }}
{{- '\n<name>' ~ param_name ~ '</name>' }}
{%- if param_fields.type is defined %}
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
{%- endif %}
{%- if param_fields.description is defined %}
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
{%- endif %}
{{- render_item_list(param_fields.enum, 'enum') }}
{%- set handled_keys = ['type', 'description', 'enum', 'required'] %}
{%- for json_key in param_fields.keys() | reject("in", handled_keys) %}
{%- set normed_json_key = json_key | replace("-", "_") | replace(" ", "_") | replace("$", "") %}
{%- if param_fields[json_key] is mapping %}
{{- '\n<' ~ normed_json_key ~ '>' ~ (param_fields[json_key] | tojson | safe) ~ '</' ~ normed_json_key ~ '>' }}
{%- else %}
{{-'\n<' ~ normed_json_key ~ '>' ~ (param_fields[json_key] | string) ~ '</' ~ normed_json_key ~ '>' }}
{%- endif %}
{%- endfor %}
{{- render_item_list(param_fields.required, 'required') }}
{{- '\n</parameter>' }}
{%- endfor %}
{{- render_item_list(tool.parameters.required, 'required') }}
{{- '\n</parameters>' }}
{%- if tool.return is defined %}
{%- if tool.return is mapping %}
{{- '\n<return>' ~ (tool.return | tojson | safe) ~ '</return>' }}
{%- else %}
{{- '\n<return>' ~ (tool.return | string) ~ '</return>' }}
{%- endif %}
{%- endif %}
{{- '\n</function>' }}
{%- endfor %}
{{- "\n</tools>" }}
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
{%- endif %}
{%- if system_message is defined %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if tools is iterable and tools | length > 0 %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in loop_messages %}
{%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
{{- '<|im_start|>' + message.role }}
{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
{{- '\n' + message.content | trim + '\n' }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- if tool_call.arguments is defined %}
{%- for args_name, args_value in tool_call.arguments|items %}
{{- '<parameter=' + args_name + '>\n' }}
{%- set args_value = args_value if args_value is string else args_value | string %}
{{- args_value }}
{{- '\n</parameter>\n' }}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "tool" %}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>user\n' }}
{%- endif %}
{{- '<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>\n' }}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>\n' }}
{%- elif loop.last %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
65 changes: 62 additions & 3 deletions tests/test-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,65 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
}
{
auto tmpls = read_templates("models/templates/Qwen-Qwen3-Coder-30B-A3B-Instruct.jinja");
std::vector<std::string> end_tokens{ "<|im_end|>" };

assert_equals(COMMON_CHAT_FORMAT_QWEN3, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_QWEN3, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

// Test parsing
assert_msg_equals(
simple_assist_msg("", "", "python", "{\"code\":\"print('Hello, World!')\"}"),
common_chat_parse(
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"\"print('Hello, World!')\"\n"
"</parameter>\n"
"</function>\n"
"</tool_call>\n",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_QWEN3}));
assert_msg_equals(
simple_assist_msg("Prelude", "", "python", "{\"code\":\"print('Hello, World!')\"}"),
common_chat_parse(
"Prelude"
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"\"print('Hello, World!')\"\n"
"</parameter>\n"
"</function>\n"
"</tool_call>\n",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_QWEN3}));
assert_msg_equals(
simple_assist_msg("Prelude", "Thoughts", "python", "{\"code\":\"print('Hello, World!')\"}"),
common_chat_parse(
"<think>Thoughts</think>Prelude"
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"\"print('Hello, World!')\"\n"
"</parameter>\n"
"</function>\n"
"</tool_call>\n",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_QWEN3,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
}));

test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"<tool_call>\n"
"<function=special_function>\n"
"<parameter=arg1>\n"
"1\n"
"</parameter>\n"
"</function>\n"
"</tool_call>");
}
{
auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
std::vector<std::string> end_tokens{ "<|im_end|>" };
Expand Down Expand Up @@ -1560,9 +1619,9 @@ int main(int argc, char ** argv) {
} else
#endif
{
test_msg_diffs_compute();
test_msgs_oaicompat_json_conversion();
test_tools_oaicompat_json_conversion();
// test_msg_diffs_compute();
// test_msgs_oaicompat_json_conversion();
// test_tools_oaicompat_json_conversion();
test_template_output_parsers();
std::cout << "\n[chat] All tests passed!" << '\n';
}
Expand Down
Loading
Loading