Skip to content

Commit cf9a0d6

Browse files
committed
gpt-oss : add support for recipient in role header
1 parent da67163 commit cf9a0d6

File tree

2 files changed

+126
-35
lines changed

2 files changed

+126
-35
lines changed

common/chat.cpp

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,24 +1318,60 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
13181318
if (inputs.tools.is_array() && !inputs.tools.empty()) {
13191319
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
13201320
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1321-
std::vector<std::string> tool_rules;
1321+
// tool calls can appear in commentary or analysis channels
1322+
auto channel = builder.add_rule("channel", "\"<|channel|>\" ( \"commentary\" | \"analysis\" )");
1323+
1324+
std::vector<std::string> tool_rules_recipient_in_role;
1325+
std::vector<std::string> tool_rules_recipient_in_channel;
13221326
foreach_function(inputs.tools, [&](const json & tool) {
13231327
const auto & function = tool.at("function");
13241328
std::string name = function.at("name");
13251329
auto parameters = function.at("parameters");
13261330
builder.resolve_refs(parameters);
13271331

1328-
tool_rules.push_back(builder.add_rule(name + "-call",
1329-
"\"" + name + "\"" + " space \"<|constrain|>\"? \"json\" space \"<|message|>\" " + builder.add_schema(name + "-args", parameters)
1330-
));
1332+
tool_rules_recipient_in_role.push_back(
1333+
builder.add_rule(name + "-call",
1334+
"\"" + name + "\"" + channel + " \" <|constrain|>json\"? \"<|message|>\" " +
1335+
builder.add_schema(name + "-args", parameters)
1336+
)
1337+
);
1338+
1339+
tool_rules_recipient_in_channel.push_back(
1340+
builder.add_rule(name + "-call",
1341+
"\"" + name + "\"" + " \" <|constrain|>json\"? \"<|message|>\" " +
1342+
builder.add_schema(name + "-args", parameters)
1343+
)
1344+
);
13311345
});
13321346

1333-
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
1334-
builder.add_rule("root", "\"<|channel|>commentary to=functions.\" " + tool_call);
1347+
auto recipient_in_role = builder.add_rule("recipient_in_role",
1348+
"\"<|start|>assistant\"? \" to=functions.\" " +
1349+
string_join(tool_rules_recipient_in_role, " | ")
1350+
);
13351351

1352+
auto recipient_in_channel = builder.add_rule("recipient_in_channel",
1353+
channel + " \" to=functions.\" " +
1354+
string_join(tool_rules_recipient_in_channel, " | ")
1355+
);
1356+
1357+
builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);
1358+
1359+
// Trigger on tool calls that appear in the commentary channel
13361360
data.grammar_triggers.push_back({
13371361
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1338-
"<\\|channel\\|>commentary to"
1362+
"<\\|channel\\|>(commentary|analysis) to"
1363+
});
1364+
1365+
// Trigger tool calls that appear in the role section, either at the
1366+
// start or in the middle.
1367+
data.grammar_triggers.push_back({
1368+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1369+
"^ to"
1370+
});
1371+
1372+
data.grammar_triggers.push_back({
1373+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1374+
"<\\|start\\|>assistant to"
13391375
});
13401376

13411377
data.preserved_tokens = {
@@ -1353,12 +1389,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
13531389
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
13541390
static const common_regex message_regex("<\\|message\\|>");
13551391
static const common_regex channel_regex("<\\|channel\\|>(final|analysis|commentary)");
1392+
static const common_regex tool_call_channel_regex("<\\|channel\\|>(commentary|analysis)");
13561393
static const common_regex start_regex("<\\|start\\|>assistant");
13571394
static const common_regex end_regex("<\\|end\\|>");
13581395
static const common_regex to_regex(" to=");
1359-
static const common_regex user_tool_call_regex(
1360-
"functions\\.([a-zA-Z_][a-zA-Z0-9_]*)\\s*(?:(?:<\\|constrain\\|>)?([a-zA-Z]+))?\\s*<\\|message\\|>"
1361-
);
1396+
static const common_regex function_regex("functions\\.([a-zA-Z_][a-zA-Z0-9_]*)");
1397+
static const common_regex user_tool_call_regex("(?: <\\|constrain\\|>([a-zA-Z]+))?<\\|message\\|>");
13621398
static const common_regex builtin_tool_call_regex("(?:browser|python)[\\s\\S]*<\\|message\\|>");
13631399

13641400
// Save the start of the message so we can roll back when we encounter a tool call and parse_tool_calls == false.
@@ -1386,40 +1422,51 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
13861422
return false;
13871423
};
13881424

1389-
auto tool_call = [&]() {
1425+
auto tool_call = [&](bool recipient_in_role) {
13901426
if (!builder.syntax().parse_tool_calls) {
13911427
// Move back to the start and consume up to the next message
13921428
builder.move_to(message_start_pos);
13931429
builder.add_content(consume_until_next(message_start_pos + 1));
13941430
return;
13951431
}
13961432

1397-
if (auto res = builder.try_consume_regex(user_tool_call_regex)) {
1433+
if (auto res = builder.try_consume_regex(function_regex)) {
13981434
auto name = builder.str(res->groups[1]);
1399-
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
1400-
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
1401-
throw common_chat_msg_partial_exception("incomplete tool call");
1435+
1436+
if (recipient_in_role) {
1437+
if (!builder.try_consume_regex(tool_call_channel_regex)) {
1438+
throw common_chat_msg_parse_exception("expected <|channel|>(commentary|analysis), got: " + consume_until_next());
1439+
}
1440+
}
1441+
1442+
if (builder.try_consume_regex(user_tool_call_regex)) {
1443+
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
1444+
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
1445+
throw common_chat_msg_partial_exception("incomplete tool call");
1446+
}
14021447
}
1448+
} else {
1449+
throw common_chat_msg_parse_exception("expected function args, got: " + consume_until_next());
14031450
}
14041451
} else if (builder.try_consume_regex(builtin_tool_call_regex)) {
14051452
builder.consume_rest();
14061453
LOG_ERR("builtin tool calls not implemented\n");
14071454
} else {
1408-
throw common_chat_msg_parse_exception("expected function call, got: " + consume_until_next());
1455+
throw common_chat_msg_parse_exception("expected function name, got: " + consume_until_next());
14091456
}
14101457
};
14111458

14121459
auto commentary = [&]() {
14131460
if (builder.try_consume_regex(to_regex)) {
1414-
tool_call();
1461+
tool_call(false);
14151462
} else if (!try_consume_message()) {
14161463
throw common_chat_msg_parse_exception("expected: \" to=\" or <|message|>, got: " + consume_until_next());
14171464
}
14181465
};
14191466

14201467
auto analysis = [&]() {
14211468
if (builder.try_consume_regex(to_regex)) {
1422-
tool_call(); // built-in tools can be called in the analysis channel
1469+
tool_call(false); // built-in tools can be called in the analysis channel
14231470
} else if (builder.try_consume_regex(message_regex)) {
14241471
// Defer reasoning parsing to builder
14251472
builder.move_to(channel_start_pos);
@@ -1434,26 +1481,34 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
14341481
}
14351482
};
14361483

1437-
auto channel = [&]() {
1484+
auto channel = [&](const common_chat_msg_parser::find_regex_result & match) {
1485+
auto type = builder.str(match.groups[1]);
1486+
if (type == "analysis") {
1487+
analysis();
1488+
} else if (type == "commentary") {
1489+
commentary();
1490+
} else if (type == "final") {
1491+
if (!try_consume_message()) {
1492+
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1493+
}
1494+
} else {
1495+
throw common_chat_msg_parse_exception("expected one of: [analysis, commentary, final], got: " + consume_until_next());
1496+
}
1497+
};
1498+
1499+
auto message = [&]() {
14381500
if (auto res = builder.try_consume_regex(channel_regex)) {
14391501
channel_start_pos = res->groups[0].begin;
1440-
auto type = builder.str(res->groups[1]);
1441-
if (type == "analysis") {
1442-
analysis();
1443-
} else if (type == "commentary") {
1444-
commentary();
1445-
} else if (type == "final") {
1446-
if (!try_consume_message()) {
1447-
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1448-
}
1449-
}
1502+
channel(*res);
1503+
} else if (builder.try_consume_regex(to_regex)) {
1504+
tool_call(true);
14501505
} else {
1451-
throw common_chat_msg_parse_exception("expected: <|channel|>, got: " + consume_until_next());
1506+
throw common_chat_msg_parse_exception("expected: <|channel|> or \" to\", got: " + consume_until_next());
14521507
}
14531508
};
14541509

14551510
try {
1456-
channel();
1511+
message();
14571512
} catch (const common_chat_msg_parse_exception & e) {
14581513
LOG_DBG("Parse error: %s\n", e.what());
14591514
}
@@ -1462,7 +1517,7 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
14621517
while (auto res = builder.try_consume_regex(start_regex)) {
14631518
message_start_pos = res->groups[0].begin;
14641519
try {
1465-
channel();
1520+
message();
14661521
} catch (const common_chat_msg_parse_exception & e) {
14671522
LOG_DBG("Parse error: %s\n", e.what());
14681523
}

tests/test-chat.cpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,7 +1615,7 @@ static void test_template_output_parsers() {
16151615
assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1"),
16161616
common_chat_parse(
16171617
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1618-
"<|start|>assistant<|channel|>commentary to=functions.special_function <|message|>{\"arg1",
1618+
"<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1",
16191619
/* is_partial= */ true,
16201620
{
16211621
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
@@ -1630,6 +1630,15 @@ static void test_template_output_parsers() {
16301630
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
16311631
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
16321632
}));
1633+
assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"),
1634+
common_chat_parse(
1635+
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1636+
"<|start|>assistant<|channel|>analysis to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}",
1637+
/* is_partial= */ false,
1638+
{
1639+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1640+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1641+
}));
16331642
assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"),
16341643
common_chat_parse(
16351644
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
@@ -1666,11 +1675,11 @@ static void test_template_output_parsers() {
16661675
}));
16671676
assert_msg_equals(
16681677
simple_assist_msg(
1669-
"<|start|>assistant<|channel|>commentary to=functions.special_function <|message|>{\"arg1",
1678+
"<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1",
16701679
"I'm\nthinking"),
16711680
common_chat_parse(
16721681
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1673-
"<|start|>assistant<|channel|>commentary to=functions.special_function <|message|>{\"arg1",
1682+
"<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1",
16741683
/* is_partial= */ true,
16751684
{
16761685
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
@@ -1732,6 +1741,33 @@ static void test_template_output_parsers() {
17321741
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
17331742
/* .reasoning_in_content = */ true,
17341743
}));
1744+
1745+
// Test tool calling in role header
1746+
assert_msg_equals(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"),
1747+
common_chat_parse(
1748+
" to=functions.special_function<|channel|>commentary <|constrain|>json<|message|>{\"arg1\": 1}",
1749+
/* is_partial= */ false,
1750+
{
1751+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1752+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1753+
}));
1754+
assert_msg_equals(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"),
1755+
common_chat_parse(
1756+
" to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}",
1757+
/* is_partial= */ false,
1758+
{
1759+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1760+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1761+
}));
1762+
assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"),
1763+
common_chat_parse(
1764+
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1765+
"<|start|>assistant to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}",
1766+
/* is_partial= */ false,
1767+
{
1768+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1769+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1770+
}));
17351771
}
17361772
}
17371773

0 commit comments

Comments
 (0)