Skip to content

Commit 969368a

Browse files
committed
model : add llama 3.x nemotron reasoning/tool parsing
1 parent 5fd160b commit 969368a

File tree

8 files changed

+215
-1
lines changed

8 files changed

+215
-1
lines changed

common/chat-parser.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,19 @@ bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
106106
return true;
107107
}
108108

109+
bool common_chat_msg_parser::try_consume_partial_literal(const std::string & literal) {
110+
if (is_partial_) {
111+
auto idx = string_find_partial_stop(input_, literal);
112+
if (idx != std::string::npos && idx >= pos_) {
113+
auto end = input_.size();
114+
if (end < idx + literal.size()) {
115+
throw common_chat_msg_partial_exception(literal);
116+
}
117+
}
118+
}
119+
return try_consume_literal(literal);
120+
}
121+
109122
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
110123
auto idx = input_.find(literal, pos_);
111124
if (idx != std::string::npos) {
@@ -153,7 +166,7 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
153166
}
154167
};
155168
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
156-
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
169+
if (syntax_.thinking_forced_open || try_consume_partial_literal(start_think)) {
157170
if (auto res = try_find_literal(end_think)) {
158171
handle_reasoning(res->prelude, /* closed */ true);
159172
consume_spaces();

common/chat-parser.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class common_chat_msg_parser {
8282
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
8383

8484
bool try_consume_literal(const std::string & literal);
85+
bool try_consume_partial_literal(const std::string & literal);
8586

8687
std::optional<find_regex_result> try_find_literal(const std::string & literal);
8788

common/chat.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ const char * common_chat_format_name(common_chat_format format) {
600600
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
601601
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
602602
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
603+
case COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON: return "Llama 3.x Nemotron";
603604
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
604605
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
605606
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
@@ -1854,6 +1855,63 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
18541855
}
18551856
}
18561857

1858+
// Build the chat params (grammar + prompt) for the Llama 3.x Nemotron template,
// which wraps tool calls in <TOOLCALL>[...]</TOOLCALL> as a JSON array.
static common_chat_params common_chat_params_init_llama_3_x_nemotron(const common_chat_template & tmpl, const struct templates_params & inputs) {
    common_chat_params data;
    // Lazy grammar: constrain sampling only once the "<TOOLCALL>" trigger fires,
    // unless tool use is required — then the grammar applies from the start.
    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        // One schema per declared tool: {"name": <const name>, "arguments": <tool params>}.
        // NOTE: property insertion order matters for the generated grammar (ordered JSON).
        auto schemas = json::array();
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & function = tool.at("function");
            schemas.push_back({
                {"type", "object"},
                {"properties", {
                    {"name", {
                        {"type", "string"},
                        {"const", function.at("name")},
                    }},
                    {"arguments", function.at("parameters")},
                }},
                {"required", json::array({"name", "arguments"})},
            });
        });
        // The model emits a non-empty JSON array of tool calls between the tags.
        auto schema = json {
            {"type", "array"},
            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
            {"minItems", 1},
        };
        if (!inputs.parallel_tool_calls) {
            // Parallel tool calls disabled: allow at most one call per block.
            schema["maxItems"] = 1;
        }
        builder.add_rule("root", "\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) + " \"</TOOLCALL>\"");
    });
    data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<TOOLCALL>"});
    // Keep the tag strings as intact tokens so the streaming parser can match them.
    data.preserved_tokens = {
        "<TOOLCALL>",
        "</TOOLCALL>"
    };
    data.prompt = apply(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON;
    return data;
}
1896+
// Parse Nemotron output: optional <think> reasoning, then an optional
// <TOOLCALL>[...]</TOOLCALL> JSON array of tool calls, then plain content.
static void common_chat_parse_llama_3_x_nemotron(common_chat_msg_parser & builder) {
    builder.try_parse_reasoning("<think>", "</think>");

    if (!builder.syntax().parse_tool_calls) {
        // Tool-call extraction disabled: everything remaining is plain content.
        builder.add_content(builder.consume_rest());
        return;
    }

    static const common_regex open_regex("<TOOLCALL>");
    static const common_regex close_regex("</TOOLCALL>");
    static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};

    if (builder.try_find_regex(open_regex)) {
        // Payload is a JSON array of {"name", "arguments"} objects; arguments are
        // re-dumped to strings for the common_chat_msg representation.
        auto calls = builder.consume_json_with_dumped_args(args_paths);
        if (!builder.add_tool_calls(calls.value) || calls.is_partial) {
            throw common_chat_msg_partial_exception("incomplete tool call array");
        }
        builder.consume_regex(close_regex);
    }
    builder.add_content(builder.consume_rest());
}
1914+
18571915
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
18581916
common_chat_params data;
18591917
data.prompt = apply(tmpl, inputs);
@@ -1968,6 +2026,11 @@ static common_chat_params common_chat_templates_apply_jinja(
19682026
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
19692027
}
19702028

2029+
// Llama 3.x Nemotron (w/ tools)
2030+
if (src.find("<TOOLCALL>") != std::string::npos) {
2031+
return common_chat_params_init_llama_3_x_nemotron(tmpl, params);
2032+
}
2033+
19712034
// Plain handler (no tools)
19722035
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
19732036
return common_chat_params_init_without_tools(tmpl, params);
@@ -2074,6 +2137,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
20742137
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
20752138
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
20762139
break;
2140+
case COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON:
2141+
common_chat_parse_llama_3_x_nemotron(builder);
2142+
break;
20772143
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
20782144
common_chat_parse_deepseek_r1(builder);
20792145
break;

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ enum common_chat_format {
103103
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
104104
COMMON_CHAT_FORMAT_LLAMA_3_X,
105105
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
106+
COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
106107
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
107108
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
108109
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,

models/templates/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ These templates can be updated with the following commands:
1818
./scripts/get_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 > models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja
1919
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
2020
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
21+
./scripts/get_chat_template.py nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 > models/templates/nvidia-Llama-3_3-Nemotron-Super-49B-v1_5.jinja
2122
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
2223
./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
2324
./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
{#- Llama 3.x Nemotron chat template: system/user/assistant/tool turns with <|start_header_id|>…<|end_header_id|> headers; '/think' / '/no_think' in the system message toggles reasoning; tool calls rendered as <TOOLCALL>[…]</TOOLCALL> and tool results as <TOOL_RESPONSE>[…]</TOOL_RESPONSE>. Whitespace is significant — this comment is fully trimmed. -#}
{% set bos = "<|begin_of_text|>" %}{%- set enable_thinking = true -%}{% set system_start_header = "<|start_header_id|>" %}{% set system_end_header = "<|end_header_id|>

" %}{% set start_header = "<|start_header_id|>" %}{% set end_header = "<|end_header_id|>

" %}{% set eot = "<|eot_id|>" %}{% set system_token = "system" %}{% set user_token = "user" %}{% set assistant_token = "assistant" %}{% set tool_token = "tool" %}{{- bos ~ system_start_header ~ system_token ~ system_end_header -}}{%- if messages[0].role == 'system' and messages[0].content != '' -%}{%- set system_content = messages[0].content -%}{%- if '/no_think' in system_content -%}{%- set system_content = system_content.replace('/no_think', '')|trim -%}{%- set enable_thinking = false -%}{%- elif '/think' in system_content -%}{%- set system_content = system_content.replace('/think', '')|trim -%}{%- set enable_thinking = true -%}{%- endif -%}{{- system_content + '

' -}}{%- endif -%}{%- if tools -%}{{- 'You can use the following tools to assist the user if required:
<AVAILABLE_TOOLS>[' -}}{%- for tool in tools -%}{{- (tool.function if tool.function is defined else tool) | tojson -}}{{- ', ' if not loop.last else '' -}}{%- endfor -%}{{- ']</AVAILABLE_TOOLS>

If you decide to call any tool(s), use the following format:
<TOOLCALL>[{{"name": "tool_name1", "arguments": "tool_args1"}}, {{"name": "tool_name2", "arguments": "tool_args2"}}]</TOOLCALL>

Response from tool(s) will be returned in this format:
<TOOL_RESPONSE>[{{"response": "tool_response1"}}, {{"response": "tool_response2"}}]</TOOL_RESPONSE>

Based on the results returned by the tool(s), you can call additional tools if needed, correct tool calls if any errors are found, or just respond with the answer to the user.' -}}{%- endif -%}{{- eot -}}{%- for message in messages -%}{%- if message.role == user_token -%}{{- start_header ~ user_token ~ end_header -}}{{ message.content -}}{{ eot -}}{%- elif message.role == assistant_token -%}{%- if '</think>' in message.content -%}{%- set content = message.content.split('</think>')[-1].lstrip() -%}{%- else -%}{%- set content = message.content -%}{%- endif -%}{{- start_header ~ assistant_token ~ end_header -}}{{ content -}}{%- if message.tool_calls -%}{{- '<TOOLCALL>[' -}}{%- for call in message.tool_calls -%}{%- set fn = call.function if call.function is defined else call -%}{{- '{"name": "' + fn.name + '", "arguments": ' -}}{%- if fn.arguments is string -%}{{- fn.arguments -}}{%- else -%}{{- fn.arguments | tojson -}}{%- endif -%}{{- '}' + (', ' if not loop.last else '') -}}{%- endfor -%}{{- ']</TOOLCALL>' -}}{%- endif -%}{{- eot -}}{%- elif message.role == tool_token -%}{%- if loop.first or (messages[loop.index0 - 1].role != tool_token) -%}{{- start_header ~ tool_token ~ end_header -}}{{ '<TOOL_RESPONSE>[' -}}{%- endif -%}{{- message.content -}}{{- ', ' if not loop.last and (messages[loop.index0 + 1].role == tool_token) else '' -}}{%- if loop.last or (messages[loop.index0 + 1].role != tool_token) -%}{{- ']</TOOL_RESPONSE>' -}}{{ eot -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{- start_header ~ assistant_token ~ end_header -}}{%- if not enable_thinking -%}{{- '<think>

</think>

' -}}{%- endif -%}{%- endif -%}

tests/test-chat-parser.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,47 @@ static void test_reasoning() {
9999
assert_equals("<think>Cogito</think>", builder.result().content);
100100
assert_equals("Ergo sum", builder.consume_rest());
101101
}
102+
{
103+
common_chat_msg_parser builder("<tnk>Cogito", /* is_partial= */ true, {
104+
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
105+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
106+
/* .reasoning_in_content = */ false,
107+
/* .thinking_forced_open = */ false,
108+
});
109+
110+
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
111+
assert_equals("Cogito", builder.result().reasoning_content);
112+
assert_equals("", builder.consume_rest());
113+
}
114+
{
115+
common_chat_msg_parser builder("<t", /* is_partial= */ true, {
116+
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
117+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
118+
/* .reasoning_in_content = */ false,
119+
/* .thinking_forced_open = */ false,
120+
});
121+
122+
try {
123+
builder.try_parse_reasoning("<tnk>", "</tnk>");
124+
throw std::runtime_error("Expected exception");
125+
} catch (const std::exception & e) {
126+
if (std::string(e.what()).find("<tnk>") == std::string::npos) {
127+
throw std::runtime_error("Expected exception about partial <tnk>");
128+
}
129+
}
130+
}
131+
{
132+
common_chat_msg_parser builder("<think>Cogito", /* is_partial= */ true, {
133+
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
134+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
135+
/* .reasoning_in_content = */ false,
136+
/* .thinking_forced_open = */ false,
137+
});
138+
139+
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
140+
assert_equals("", builder.result().reasoning_content);
141+
assert_equals("<think>Cogito", builder.consume_rest());
142+
}
102143
}
103144

104145
static void test_regex() {

tests/test-chat.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,77 @@ static void test_template_output_parsers() {
14391439
/* expect_grammar_triggered= */ false
14401440
);
14411441
}
1442+
{
1443+
auto tmpls = read_templates("models/templates/nvidia-Llama-3_3-Nemotron-Super-49B-v1_5.jinja");
1444+
std::vector<std::string> end_tokens{ "<|eot_id|>" };
1445+
1446+
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
1447+
1448+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
1449+
test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
1450+
1451+
assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
1452+
common_chat_parse(
1453+
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
1454+
/* is_partial= */ false,
1455+
{COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON}));
1456+
assert_msg_equals(message_assist_thoughts,
1457+
common_chat_parse(
1458+
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
1459+
/* is_partial= */ false,
1460+
{
1461+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1462+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1463+
}));
1464+
assert_msg_equals(message_assist_thoughts,
1465+
common_chat_parse(
1466+
"I'm\nthinking</think>Hello, world!\nWhat's up?",
1467+
/* is_partial= */ false,
1468+
{
1469+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1470+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1471+
/* .reasoning_in_content = */ false,
1472+
/* .thinking_forced_open = */ true,
1473+
}));
1474+
1475+
assert_msg_equals(message_assist_call_thoughts_unparsed,
1476+
common_chat_parse(
1477+
"<think>I'm\nthinking</think>\n\n"
1478+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1479+
/* is_partial= */ false,
1480+
{COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON}));
1481+
assert_msg_equals(message_assist_call,
1482+
common_chat_parse(
1483+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1484+
/* is_partial= */ false,
1485+
{COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON}));
1486+
assert_msg_equals(message_assist_call_thoughts,
1487+
common_chat_parse(
1488+
"<think>I'm\nthinking</think>\n\n"
1489+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1490+
/* is_partial= */ false,
1491+
{
1492+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1493+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1494+
}));
1495+
1496+
assert_msg_equals(message_assist_empty,
1497+
common_chat_parse(
1498+
"<th",
1499+
/* is_partial= */ true,
1500+
{
1501+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1502+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1503+
}));
1504+
assert_msg_equals(message_assist_thoughts_no_content,
1505+
common_chat_parse(
1506+
"<think>I'm\nthinking",
1507+
/* is_partial= */ true,
1508+
{
1509+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1510+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1511+
}));
1512+
}
14421513
}
14431514

14441515
static void test_msg_diffs_compute() {

0 commit comments

Comments
 (0)