Skip to content

Commit cfb2f23

Browse files
committed
gpt-oss : add support for recipient in role header
1 parent a9ad8f6 commit cfb2f23

File tree

2 files changed

+126
-35
lines changed

2 files changed

+126
-35
lines changed

common/chat.cpp

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,24 +1342,60 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
13421342
if (inputs.tools.is_array() && !inputs.tools.empty()) {
13431343
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
13441344
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1345-
std::vector<std::string> tool_rules;
1345+
// tool calls can appear in commentary or analysis channels
1346+
auto channel = builder.add_rule("channel", "\"<|channel|>\" ( \"commentary\" | \"analysis\" )");
1347+
1348+
std::vector<std::string> tool_rules_recipient_in_role;
1349+
std::vector<std::string> tool_rules_recipient_in_channel;
13461350
foreach_function(inputs.tools, [&](const json & tool) {
13471351
const auto & function = tool.at("function");
13481352
std::string name = function.at("name");
13491353
auto parameters = function.at("parameters");
13501354
builder.resolve_refs(parameters);
13511355

1352-
tool_rules.push_back(builder.add_rule(name + "-call",
1353-
"\"" + name + "\"" + " space \"<|constrain|>\"? \"json\" space \"<|message|>\" " + builder.add_schema(name + "-args", parameters)
1354-
));
1356+
tool_rules_recipient_in_role.push_back(
1357+
builder.add_rule(name + "-call",
1358+
"\"" + name + "\"" + channel + " \" <|constrain|>json\"? \"<|message|>\" " +
1359+
builder.add_schema(name + "-args", parameters)
1360+
)
1361+
);
1362+
1363+
tool_rules_recipient_in_channel.push_back(
1364+
builder.add_rule(name + "-call",
1365+
"\"" + name + "\"" + " \" <|constrain|>json\"? \"<|message|>\" " +
1366+
builder.add_schema(name + "-args", parameters)
1367+
)
1368+
);
13551369
});
13561370

1357-
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
1358-
builder.add_rule("root", "\"<|channel|>commentary to=functions.\" " + tool_call);
1371+
auto recipient_in_role = builder.add_rule("recipient_in_role",
1372+
"\"<|start|>assistant\"? \" to=functions.\" " +
1373+
string_join(tool_rules_recipient_in_role, " | ")
1374+
);
13591375

1376+
auto recipient_in_channel = builder.add_rule("recipient_in_channel",
1377+
channel + " \" to=functions.\" " +
1378+
string_join(tool_rules_recipient_in_channel, " | ")
1379+
);
1380+
1381+
builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);
1382+
1383+
// Trigger on tool calls that appear in the commentary or analysis channel
13601384
data.grammar_triggers.push_back({
13611385
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1362-
"<\\|channel\\|>commentary to"
1386+
"<\\|channel\\|>(commentary|analysis) to"
1387+
});
1388+
1389+
// Trigger on tool calls that appear in the role section, either at the
1390+
// start or in the middle.
1391+
data.grammar_triggers.push_back({
1392+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1393+
"^ to"
1394+
});
1395+
1396+
data.grammar_triggers.push_back({
1397+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1398+
"<\\|start\\|>assistant to"
13631399
});
13641400

13651401
data.preserved_tokens = {
@@ -1377,12 +1413,12 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
13771413
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
13781414
static const common_regex message_regex("<\\|message\\|>");
13791415
static const common_regex channel_regex("<\\|channel\\|>(final|analysis|commentary)");
1416+
static const common_regex tool_call_channel_regex("<\\|channel\\|>(commentary|analysis)");
13801417
static const common_regex start_regex("<\\|start\\|>assistant");
13811418
static const common_regex end_regex("<\\|end\\|>");
13821419
static const common_regex to_regex(" to=");
1383-
static const common_regex user_tool_call_regex(
1384-
"functions\\.([a-zA-Z_][a-zA-Z0-9_]*)\\s*(?:(?:<\\|constrain\\|>)?([a-zA-Z]+))?\\s*<\\|message\\|>"
1385-
);
1420+
static const common_regex function_regex("functions\\.([a-zA-Z_][a-zA-Z0-9_]*)");
1421+
static const common_regex user_tool_call_regex("(?: <\\|constrain\\|>([a-zA-Z]+))?<\\|message\\|>");
13861422
static const common_regex builtin_tool_call_regex("(?:browser|python)[\\s\\S]*<\\|message\\|>");
13871423

13881424
// Save the start of the message so we can roll back when we encounter a tool call and parse_tool_calls == false.
@@ -1410,40 +1446,51 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
14101446
return false;
14111447
};
14121448

1413-
auto tool_call = [&]() {
1449+
auto tool_call = [&](bool recipient_in_role) {
14141450
if (!builder.syntax().parse_tool_calls) {
14151451
// Move back to the start and consume up to the next message
14161452
builder.move_to(message_start_pos);
14171453
builder.add_content(consume_until_next(message_start_pos + 1));
14181454
return;
14191455
}
14201456

1421-
if (auto res = builder.try_consume_regex(user_tool_call_regex)) {
1457+
if (auto res = builder.try_consume_regex(function_regex)) {
14221458
auto name = builder.str(res->groups[1]);
1423-
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
1424-
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
1425-
throw common_chat_msg_partial_exception("incomplete tool call");
1459+
1460+
if (recipient_in_role) {
1461+
if (!builder.try_consume_regex(tool_call_channel_regex)) {
1462+
throw common_chat_msg_parse_exception("expected <|channel|>(commentary|analysis), got: " + consume_until_next());
1463+
}
1464+
}
1465+
1466+
if (builder.try_consume_regex(user_tool_call_regex)) {
1467+
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
1468+
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
1469+
throw common_chat_msg_partial_exception("incomplete tool call");
1470+
}
14261471
}
1472+
} else {
1473+
throw common_chat_msg_parse_exception("expected function args, got: " + consume_until_next());
14271474
}
14281475
} else if (builder.try_consume_regex(builtin_tool_call_regex)) {
14291476
builder.consume_rest();
14301477
LOG_ERR("builtin tool calls not implemented\n");
14311478
} else {
1432-
throw common_chat_msg_parse_exception("expected function call, got: " + consume_until_next());
1479+
throw common_chat_msg_parse_exception("expected function name, got: " + consume_until_next());
14331480
}
14341481
};
14351482

14361483
auto commentary = [&]() {
14371484
if (builder.try_consume_regex(to_regex)) {
1438-
tool_call();
1485+
tool_call(false);
14391486
} else if (!try_consume_message()) {
14401487
throw common_chat_msg_parse_exception("expected: \" to=\" or <|message|>, got: " + consume_until_next());
14411488
}
14421489
};
14431490

14441491
auto analysis = [&]() {
14451492
if (builder.try_consume_regex(to_regex)) {
1446-
tool_call(); // built-in tools can be called in the analysis channel
1493+
tool_call(false); // built-in tools can be called in the analysis channel
14471494
} else if (builder.try_consume_regex(message_regex)) {
14481495
// Defer reasoning parsing to builder
14491496
builder.move_to(channel_start_pos);
@@ -1458,26 +1505,34 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
14581505
}
14591506
};
14601507

1461-
auto channel = [&]() {
1508+
auto channel = [&](const common_chat_msg_parser::find_regex_result & match) {
1509+
auto type = builder.str(match.groups[1]);
1510+
if (type == "analysis") {
1511+
analysis();
1512+
} else if (type == "commentary") {
1513+
commentary();
1514+
} else if (type == "final") {
1515+
if (!try_consume_message()) {
1516+
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1517+
}
1518+
} else {
1519+
throw common_chat_msg_parse_exception("expected one of: [analysis, commentary, final], got: " + consume_until_next());
1520+
}
1521+
};
1522+
1523+
auto message = [&]() {
14621524
if (auto res = builder.try_consume_regex(channel_regex)) {
14631525
channel_start_pos = res->groups[0].begin;
1464-
auto type = builder.str(res->groups[1]);
1465-
if (type == "analysis") {
1466-
analysis();
1467-
} else if (type == "commentary") {
1468-
commentary();
1469-
} else if (type == "final") {
1470-
if (!try_consume_message()) {
1471-
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1472-
}
1473-
}
1526+
channel(*res);
1527+
} else if (builder.try_consume_regex(to_regex)) {
1528+
tool_call(true);
14741529
} else {
1475-
throw common_chat_msg_parse_exception("expected: <|channel|>, got: " + consume_until_next());
1530+
throw common_chat_msg_parse_exception("expected: <|channel|> or \" to\", got: " + consume_until_next());
14761531
}
14771532
};
14781533

14791534
try {
1480-
channel();
1535+
message();
14811536
} catch (const common_chat_msg_parse_exception & e) {
14821537
LOG_DBG("Parse error: %s\n", e.what());
14831538
}
@@ -1486,7 +1541,7 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
14861541
while (auto res = builder.try_consume_regex(start_regex)) {
14871542
message_start_pos = res->groups[0].begin;
14881543
try {
1489-
channel();
1544+
message();
14901545
} catch (const common_chat_msg_parse_exception & e) {
14911546
LOG_DBG("Parse error: %s\n", e.what());
14921547
}

tests/test-chat.cpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,7 +1615,7 @@ static void test_template_output_parsers() {
16151615
assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1"),
16161616
common_chat_parse(
16171617
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1618-
"<|start|>assistant<|channel|>commentary to=functions.special_function <|message|>{\"arg1",
1618+
"<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1",
16191619
/* is_partial= */ true,
16201620
{
16211621
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
@@ -1630,6 +1630,15 @@ static void test_template_output_parsers() {
16301630
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
16311631
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
16321632
}));
1633+
assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"),
1634+
common_chat_parse(
1635+
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1636+
"<|start|>assistant<|channel|>analysis to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}",
1637+
/* is_partial= */ false,
1638+
{
1639+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1640+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1641+
}));
16331642
assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"),
16341643
common_chat_parse(
16351644
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
@@ -1666,11 +1675,11 @@ static void test_template_output_parsers() {
16661675
}));
16671676
assert_msg_equals(
16681677
simple_assist_msg(
1669-
"<|start|>assistant<|channel|>commentary to=functions.special_function <|message|>{\"arg1",
1678+
"<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1",
16701679
"I'm\nthinking"),
16711680
common_chat_parse(
16721681
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1673-
"<|start|>assistant<|channel|>commentary to=functions.special_function <|message|>{\"arg1",
1682+
"<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1",
16741683
/* is_partial= */ true,
16751684
{
16761685
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
@@ -1732,6 +1741,33 @@ static void test_template_output_parsers() {
17321741
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
17331742
/* .reasoning_in_content = */ true,
17341743
}));
1744+
1745+
// Test tool calling in role header
1746+
assert_msg_equals(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"),
1747+
common_chat_parse(
1748+
" to=functions.special_function<|channel|>commentary <|constrain|>json<|message|>{\"arg1\": 1}",
1749+
/* is_partial= */ false,
1750+
{
1751+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1752+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1753+
}));
1754+
assert_msg_equals(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"),
1755+
common_chat_parse(
1756+
" to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}",
1757+
/* is_partial= */ false,
1758+
{
1759+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1760+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1761+
}));
1762+
assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"),
1763+
common_chat_parse(
1764+
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1765+
"<|start|>assistant to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}",
1766+
/* is_partial= */ false,
1767+
{
1768+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1769+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1770+
}));
17351771
}
17361772
}
17371773

0 commit comments

Comments
 (0)