Merged
115 changes: 115 additions & 0 deletions common/chat.cpp
@@ -623,6 +623,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -1184,6 +1185,84 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
});
return data;
}

static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;

// Generate the prompt using the apply() function with the template
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2;

// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}

// When tools are present, build grammar for the <TOOLCALL> format
if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);

// Build tool call rule for Nemotron format
tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name + "-args", {
{"type", "object"},
{"properties", {
{"name", {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
})));
});

// Create tool call rule - Nemotron uses <TOOLCALL>[...] format
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");

// Grammar rule for <TOOLCALL> format
builder.add_rule("root",
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
"\"<TOOLCALL>\" space " + tool_list + " \"</TOOLCALL>\"");
});
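// Illustrative completion accepted by this grammar, assuming a hypothetical "get_weather" tool was provided:
//   <TOOLCALL>[{"name": "get_weather", "arguments": {"city": "Paris"}}]</TOOLCALL>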

// Define triggers for <TOOLCALL> token
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
"<TOOLCALL>"
});
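// With grammar_lazy set (tool choice not forced), the grammar above is only
// enforced once the "<TOOLCALL>" trigger word appears in the generated text.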
} else {
// Handle thinking tags for non-tool responses
if (data.thinking_forced_open && inputs.enable_thinking) {
data.grammar_lazy = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_rule("root", "\"</think>\" space");
});
}
}

// Set preserved tokens for all special tokens
data.preserved_tokens = {
"<think>",
"</think>",
"<TOOLCALL>",
"</TOOLCALL>",
"<TOOL_RESPONSE>",
"</TOOL_RESPONSE>",
"<AVAILABLE_TOOLS>",
"</AVAILABLE_TOOLS>"
};

return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
@@ -2060,6 +2139,34 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}
}

static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
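// e.g. a complete call section looks like (hypothetical tool name):
//   <TOOLCALL>[{"name": "lookup_user", "arguments": {"id": 42}}]</TOOLCALL>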
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);

// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
} else {
builder.add_content(builder.consume_rest());
}
}

static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@@ -2293,6 +2400,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_seed_oss(tmpl, params, inputs);
}

// Nemotron v2
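// ("<SPECIAL_10>" marks the System turn in the Nemotron template, so its presence identifies this format)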
if (src.find("<SPECIAL_10>") != std::string::npos) {
return common_chat_params_init_nemotron_v2(tmpl, params);
}

// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2454,6 +2566,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
1 change: 1 addition & 0 deletions common/chat.h
@@ -112,6 +112,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,

COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
32 changes: 32 additions & 0 deletions models/templates/NVIDIA-Nemotron-Nano-v2.jinja
@@ -0,0 +1,32 @@
{%- set ns = namespace(enable_thinking=true) %}{%- for message in messages -%}{%- set content = message['content'] -%}{%- if message['role'] == 'user' or message['role'] == 'system' -%}{%- if '/think' in content -%}{%- set ns.enable_thinking = true -%}{%- elif '/no_think' in content -%}{%- set ns.enable_thinking = false -%}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if messages[0]['role'] != 'system' -%}{%- set ns.non_tool_system_content = '' -%}{{- '<SPECIAL_10>System
' -}}{%- else -%}{%- set ns.non_tool_system_content = messages[0]['content'].replace('/think', '').replace('/no_think', '').strip() -%}{{- '<SPECIAL_10>System
' + ns.non_tool_system_content }}{%- endif -%}{%- if tools -%}{%- if ns.non_tool_system_content is defined and ns.non_tool_system_content != '' -%}{{- '

' -}}{%- endif -%}{{- 'You can use the following tools to assist the user if required:' -}}{{- '
<AVAILABLE_TOOLS>[' -}}{%- for tool in tools -%}{{- (tool.function if tool.function is defined else tool) | tojson -}}{{- ', ' if not loop.last else '' -}}{%- endfor -%}{{- ']</AVAILABLE_TOOLS>

' -}}{{- 'If you decide to call any tool(s), use the following format:
' -}}{{- '<TOOLCALL>[{{"name": "tool_name1", "arguments": "tool_args1"}}, ' -}}{{- '{{"name": "tool_name2", "arguments": "tool_args2"}}]</TOOLCALL>

' -}}{{- 'The user will execute tool-calls and return responses from tool(s) in this format:
' -}}{{- '<TOOL_RESPONSE>[{{"tool_response1"}}, {{"tool_response2"}}]</TOOL_RESPONSE>

' -}}{{- 'Based on the tool responses, you can call additional tools if needed, correct tool calls if any errors are found, or just respond to the user.' -}}{%- endif -%}{{- '
' -}}{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%}{%- if messages[-1]['role'] == 'assistant' -%}{%- set ns.last_turn_assistant_content = messages[-1]['content'].strip() -%}{%- set messages = messages[:-1] -%}{%- endif -%}{%- for message in messages %}{%- set content = message['content'] %}{%- if message['role'] == 'user' -%}{{- '<SPECIAL_11>User
' + content.replace('/think', '').replace('/no_think', '').strip() + '
' }}{%- elif message['role'] == 'tool' -%}{%- if loop.first or (messages[loop.index0 - 1].role != 'tool') -%}{{- '<SPECIAL_11>User
' + '<TOOL_RESPONSE>[' }}{%- endif -%}{{- message['content'] -}}{{- ', ' if not loop.last and (messages[loop.index0 + 1].role == 'tool') else '' -}}{%- if loop.last or (messages[loop.index0 + 1].role != 'tool') -%}{{- ']</TOOL_RESPONSE>
' -}}{%- endif -%}{%- elif message['role'] == 'assistant' -%}{%- if '</think>' in content -%}{%- set content = content.split('</think>')[1].strip() %}{%- endif -%}{{- '<SPECIAL_11>Assistant
' + content.strip() }}{%- if message.tool_calls -%}{%- if content.strip() != '' -%}{{- '

' -}}{%- endif -%}{{- '<TOOLCALL>[' -}}{%- for call in message.tool_calls -%}{%- set fn = call.function if call.function is defined else call -%}{{- '{"name": "' + fn.name + '", "arguments": ' -}}{%- if fn.arguments is string -%}{{- fn.arguments -}}{%- else -%}{{- fn.arguments | tojson -}}{%- endif -%}{{- '}' + (', ' if not loop.last else '') -}}{%- endfor -%}{{- ']</TOOLCALL>' -}}{%- endif -%}{{- '
<SPECIAL_12>
' -}}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{- '<SPECIAL_11>Assistant
' -}}{%- if ns.enable_thinking is defined and ns.enable_thinking is false -%}{{- '<think></think>' -}}{%- else -%}{{- '<think>
' -}}{%- endif -%}{%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%}{{- ns.last_turn_assistant_content -}}{%- endif -%}{%- else -%}{%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%}{{- '<SPECIAL_11>Assistant
' -}}{%- if ns.enable_thinking is defined and ns.enable_thinking is false -%}{{- '<think></think>' -}}{%- else -%}{{- '<think>
' -}}{%- endif -%}{{- ns.last_turn_assistant_content -}}{%- if continue_final_message is defined -%}{%- if continue_final_message is false -%}{{- '
<SPECIAL_12>
' -}}{%- endif -%}{%- else -%}{{- '
<SPECIAL_12>
' -}}{%- endif -%}{%- endif -%}{%- endif -%}
53 changes: 53 additions & 0 deletions tests/test-chat.cpp
@@ -1755,6 +1755,59 @@ static void test_template_output_parsers() {
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_SEED_OSS}));
}

{
auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-Nano-v2.jinja");
std::vector<std::string> end_tokens{ "<SPECIAL_12>" };

assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

// Test parsing regular content
assert_msg_equals(message_assist,
common_chat_parse(
"Hello, world!\nWhat's up?",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_NEMOTRON_V2}));

// Test parsing content with thinking
assert_msg_equals(message_assist_thoughts,
common_chat_parse(
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
}));

// Test parsing tool calls
assert_msg_equals(message_assist_call,
common_chat_parse(
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_NEMOTRON_V2}));

// Test template generation for regular content
test_templates(tmpls.get(), end_tokens, message_assist, tools,
"Hello, world!\nWhat's up?\n",
/* expect_grammar_triggered= */ false);

// Test template generation for tool calls
test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
"{\n"
" \"tool_calls\": [\n"
" {\n"
" \"name\": \"special_function\",\n"
" \"arguments\": {\n"
" \"arg1\": 1\n"
" },\n"
" \"id\": \"123456789\"\n"
" }\n"
" ]\n"
"}\n",
/* expect_grammar_triggered= */ false
);
}
}

static void test_msg_diffs_compute() {